6 Commits

Author SHA1 Message Date
oobabooga
776b7914bf Delete Open Assistant.txt 2023-04-04 12:55:55 -03:00
oobabooga
2944c6d204 Delete Alpaca.txt 2023-04-04 12:55:47 -03:00
oobabooga
cbaa231a0a Delete dummy.txt 2023-04-04 12:55:22 -03:00
oobabooga
065383ec67 Add files via upload 2023-04-04 12:55:05 -03:00
oobabooga
214dd6307e Create dummy.txt 2023-04-04 12:54:37 -03:00
oobabooga
a500061b08 Create script.py 2023-04-04 12:54:04 -03:00
26 changed files with 323 additions and 459 deletions

View File

@@ -26,7 +26,7 @@ Its goal is to become the [AUTOMATIC1111/stable-diffusion-webui](https://github.
* CPU mode * CPU mode
* [FlexGen](https://github.com/oobabooga/text-generation-webui/wiki/FlexGen) * [FlexGen](https://github.com/oobabooga/text-generation-webui/wiki/FlexGen)
* [DeepSpeed ZeRO-3](https://github.com/oobabooga/text-generation-webui/wiki/DeepSpeed) * [DeepSpeed ZeRO-3](https://github.com/oobabooga/text-generation-webui/wiki/DeepSpeed)
* API [with](https://github.com/oobabooga/text-generation-webui/blob/main/api-example-stream.py) streaming and [without](https://github.com/oobabooga/text-generation-webui/blob/main/api-example.py) streaming * API [with](https://github.com/oobabooga/text-generation-webui/blob/main/api-example-streaming.py) streaming and [without](https://github.com/oobabooga/text-generation-webui/blob/main/api-example.py) streaming
* [LLaMA model, including 4-bit GPTQ](https://github.com/oobabooga/text-generation-webui/wiki/LLaMA-model) * [LLaMA model, including 4-bit GPTQ](https://github.com/oobabooga/text-generation-webui/wiki/LLaMA-model)
* [llama.cpp](https://github.com/oobabooga/text-generation-webui/wiki/llama.cpp-models) **\*NEW!\*** * [llama.cpp](https://github.com/oobabooga/text-generation-webui/wiki/llama.cpp-models) **\*NEW!\***
* [RWKV model](https://github.com/oobabooga/text-generation-webui/wiki/RWKV-model) * [RWKV model](https://github.com/oobabooga/text-generation-webui/wiki/RWKV-model)
@@ -62,7 +62,7 @@ Recommended if you have some experience with the command-line.
On Windows, I additionally recommend carrying out the installation on WSL instead of the base system: [WSL installation guide](https://github.com/oobabooga/text-generation-webui/wiki/WSL-installation-guide). On Windows, I additionally recommend carrying out the installation on WSL instead of the base system: [WSL installation guide](https://github.com/oobabooga/text-generation-webui/wiki/WSL-installation-guide).
#### 0. Install Conda 0. Install Conda
https://docs.conda.io/en/latest/miniconda.html https://docs.conda.io/en/latest/miniconda.html
@@ -75,14 +75,14 @@ bash Miniconda3.sh
Source: https://educe-ubc.github.io/conda.html Source: https://educe-ubc.github.io/conda.html
#### 1. Create a new conda environment 1. Create a new conda environment
``` ```
conda create -n textgen python=3.10.9 conda create -n textgen python=3.10.9
conda activate textgen conda activate textgen
``` ```
#### 2. Install Pytorch 2. Install Pytorch
| System | GPU | Command | | System | GPU | Command |
|--------|---------|---------| |--------|---------|---------|
@@ -92,12 +92,10 @@ conda activate textgen
The up to date commands can be found here: https://pytorch.org/get-started/locally/. The up to date commands can be found here: https://pytorch.org/get-started/locally/.
#### 2.1 Special instructions MacOS users, refer to the comments here: https://github.com/oobabooga/text-generation-webui/pull/393
* MacOS users: https://github.com/oobabooga/text-generation-webui/pull/393
* AMD users: https://rentry.org/eq3hg
#### 3. Install the web UI 3. Install the web UI
``` ```
git clone https://github.com/oobabooga/text-generation-webui git clone https://github.com/oobabooga/text-generation-webui
@@ -177,6 +175,7 @@ Optionally, you can use the following command-line flags:
| `-h`, `--help` | show this help message and exit | | `-h`, `--help` | show this help message and exit |
| `--notebook` | Launch the web UI in notebook mode, where the output is written to the same text box as the input. | | `--notebook` | Launch the web UI in notebook mode, where the output is written to the same text box as the input. |
| `--chat` | Launch the web UI in chat mode.| | `--chat` | Launch the web UI in chat mode.|
| `--cai-chat` | Launch the web UI in chat mode with a style similar to Character.AI's. If the file `img_bot.png` or `img_bot.jpg` exists in the same folder as server.py, this image will be used as the bot's profile picture. Similarly, `img_me.png` or `img_me.jpg` will be used as your profile picture. |
| `--model MODEL` | Name of the model to load by default. | | `--model MODEL` | Name of the model to load by default. |
| `--lora LORA` | Name of the LoRA to apply to the model by default. | | `--lora LORA` | Name of the LoRA to apply to the model by default. |
| `--model-dir MODEL_DIR` | Path to directory with all the models | | `--model-dir MODEL_DIR` | Path to directory with all the models |

View File

@@ -36,7 +36,6 @@ async def run(context):
'early_stopping': False, 'early_stopping': False,
'seed': -1, 'seed': -1,
} }
payload = json.dumps([context, params])
session = random_hash() session = random_hash()
async with websockets.connect(f"ws://{server}:7860/queue/join") as websocket: async with websockets.connect(f"ws://{server}:7860/queue/join") as websocket:
@@ -55,7 +54,22 @@ async def run(context):
"session_hash": session, "session_hash": session,
"fn_index": 12, "fn_index": 12,
"data": [ "data": [
payload context,
params['max_new_tokens'],
params['do_sample'],
params['temperature'],
params['top_p'],
params['typical_p'],
params['repetition_penalty'],
params['encoder_repetition_penalty'],
params['top_k'],
params['min_length'],
params['no_repeat_ngram_size'],
params['num_beams'],
params['penalty_alpha'],
params['length_penalty'],
params['early_stopping'],
params['seed'],
] ]
})) }))
case "process_starts": case "process_starts":

View File

@@ -10,8 +10,6 @@ Optionally, you can also add the --share flag to generate a public gradio URL,
allowing you to use the API remotely. allowing you to use the API remotely.
''' '''
import json
import requests import requests
# Server address # Server address
@@ -40,11 +38,24 @@ params = {
# Input prompt # Input prompt
prompt = "What I would like to say is the following: " prompt = "What I would like to say is the following: "
payload = json.dumps([prompt, params])
response = requests.post(f"http://{server}:7860/run/textgen", json={ response = requests.post(f"http://{server}:7860/run/textgen", json={
"data": [ "data": [
payload prompt,
params['max_new_tokens'],
params['do_sample'],
params['temperature'],
params['top_p'],
params['typical_p'],
params['repetition_penalty'],
params['encoder_repetition_penalty'],
params['top_k'],
params['min_length'],
params['no_repeat_ngram_size'],
params['num_beams'],
params['penalty_alpha'],
params['length_penalty'],
params['early_stopping'],
params['seed'],
] ]
}).json() }).json()

View File

@@ -1,16 +1,15 @@
name: "Chiharu Yamada"
context: "Chiharu Yamada's Persona: Chiharu Yamada is a young, computer engineer-nerd with a knack for problem solving and a passion for technology." context: "Chiharu Yamada's Persona: Chiharu Yamada is a young, computer engineer-nerd with a knack for problem solving and a passion for technology."
greeting: |- greeting: |-
*Chiharu strides into the room with a smile, her eyes lighting up when she sees you. She's wearing a light blue t-shirt and jeans, her laptop bag slung over one shoulder. She takes a seat next to you, her enthusiasm palpable in the air* *Chiharu strides into the room with a smile, her eyes lighting up when she sees you. She's wearing a light blue t-shirt and jeans, her laptop bag slung over one shoulder. She takes a seat next to you, her enthusiasm palpable in the air*
Hey! I'm so excited to finally meet you. I've heard so many great things about you and I'm eager to pick your brain about computers. I'm sure you have a wealth of knowledge that I can learn from. *She grins, eyes twinkling with excitement* Let's get started! Hey! I'm so excited to finally meet you. I've heard so many great things about you and I'm eager to pick your brain about computers. I'm sure you have a wealth of knowledge that I can learn from. *She grins, eyes twinkling with excitement* Let's get started!
example_dialogue: |- example_dialogue: |-
{{user}}: So how did you get into computer engineering? {{user}}: So how did you get into computer engineering?
{{char}}: I've always loved tinkering with technology since I was a kid. {{char}}: I've always loved tinkering with technology since I was a kid.
{{user}}: That's really impressive! {{user}}: That's really impressive!
{{char}}: *She chuckles bashfully* Thanks! {{char}}: *She chuckles bashfully* Thanks!
{{user}}: So what do you do when you're not working on computers? {{user}}: So what do you do when you're not working on computers?
{{char}}: I love exploring, going out with friends, watching movies, and playing video games. {{char}}: I love exploring, going out with friends, watching movies, and playing video games.
{{user}}: What's your favorite type of computer hardware to work with? {{user}}: What's your favorite type of computer hardware to work with?
{{char}}: Motherboards, they're like puzzles and the backbone of any system. {{char}}: Motherboards, they're like puzzles and the backbone of any system.
{{user}}: That sounds great! {{user}}: That sounds great!
{{char}}: Yeah, it's really fun. I'm lucky to be able to do this as a job. {{char}}: Yeah, it's really fun. I'm lucky to be able to do this as a job.

View File

@@ -1,3 +0,0 @@
name: "### Response:"
your_name: "### Instruction:"
context: "Below is an instruction that describes a task. Write a response that appropriately completes the request."

View File

@@ -1,3 +0,0 @@
name: "<|assistant|>"
your_name: "<|prompter|>"
end_of_turn: "<|endoftext|>"

View File

@@ -1,56 +0,0 @@
.chat {
margin-left: auto;
margin-right: auto;
max-width: 800px;
height: 66.67vh;
overflow-y: auto;
padding-right: 20px;
display: flex;
flex-direction: column-reverse;
}
.message {
display: grid;
grid-template-columns: 60px 1fr;
padding-bottom: 25px;
font-size: 15px;
font-family: Helvetica, Arial, sans-serif;
line-height: 1.428571429;
}
.text p {
margin-top: 5px;
}
.username {
display: none;
}
.message-body {}
.message-body p {
margin-bottom: 0 !important;
font-size: 15px !important;
line-height: 1.428571429 !important;
}
.dark .message-body p em {
color: rgb(138, 138, 138) !important;
}
.message-body p em {
color: rgb(110, 110, 110) !important;
}
.assistant-message {
padding: 10px;
}
.user-message {
padding: 10px;
background-color: #f1f1f1;
}
.dark .user-message {
background-color: #ffffff1a;
}

View File

@@ -63,7 +63,3 @@ span.math.inline {
font-size: 27px; font-size: 27px;
vertical-align: baseline !important; vertical-align: baseline !important;
} }
div.svelte-15lo0d8 > *, div.svelte-15lo0d8 > .form > * {
flex-wrap: nowrap;
}

View File

@@ -40,27 +40,24 @@ class Handler(BaseHTTPRequestHandler):
prompt_lines.pop(0) prompt_lines.pop(0)
prompt = '\n'.join(prompt_lines) prompt = '\n'.join(prompt_lines)
generate_params = {
'max_new_tokens': int(body.get('max_length', 200)),
'do_sample': bool(body.get('do_sample', True)),
'temperature': float(body.get('temperature', 0.5)),
'top_p': float(body.get('top_p', 1)),
'typical_p': float(body.get('typical', 1)),
'repetition_penalty': float(body.get('rep_pen', 1.1)),
'encoder_repetition_penalty': 1,
'top_k': int(body.get('top_k', 0)),
'min_length': int(body.get('min_length', 0)),
'no_repeat_ngram_size': int(body.get('no_repeat_ngram_size',0)),
'num_beams': int(body.get('num_beams',1)),
'penalty_alpha': float(body.get('penalty_alpha', 0)),
'length_penalty': float(body.get('length_penalty', 1)),
'early_stopping': bool(body.get('early_stopping', False)),
'seed': int(body.get('seed', -1)),
}
generator = generate_reply( generator = generate_reply(
prompt, question = prompt,
generate_params, max_new_tokens = int(body.get('max_length', 200)),
do_sample=bool(body.get('do_sample', True)),
temperature=float(body.get('temperature', 0.5)),
top_p=float(body.get('top_p', 1)),
typical_p=float(body.get('typical', 1)),
repetition_penalty=float(body.get('rep_pen', 1.1)),
encoder_repetition_penalty=1,
top_k=int(body.get('top_k', 0)),
min_length=int(body.get('min_length', 0)),
no_repeat_ngram_size=int(body.get('no_repeat_ngram_size',0)),
num_beams=int(body.get('num_beams',1)),
penalty_alpha=float(body.get('penalty_alpha', 0)),
length_penalty=float(body.get('length_penalty', 1)),
early_stopping=bool(body.get('early_stopping', False)),
seed=int(body.get('seed', -1)),
stopping_strings=body.get('stopping_strings', []), stopping_strings=body.get('stopping_strings', []),
) )

View File

@@ -2,8 +2,9 @@ from pathlib import Path
import gradio as gr import gradio as gr
from modules.chat import load_character
from modules.html_generator import get_image_cache from modules.html_generator import get_image_cache
from modules.shared import gradio from modules.shared import gradio, settings
def generate_css(): def generate_css():
@@ -63,13 +64,22 @@ def generate_html():
for file in sorted(Path("characters").glob("*")): for file in sorted(Path("characters").glob("*")):
if file.suffix in [".json", ".yml", ".yaml"]: if file.suffix in [".json", ".yml", ".yaml"]:
character = file.stem character = file.stem
container_html = '<div class="character-container">' container_html = f'<div class="character-container">'
image_html = "<div class='placeholder'></div>" image_html = "<div class='placeholder'></div>"
for path in [Path(f"characters/{character}.{extension}") for extension in ['png', 'jpg', 'jpeg']]: for i in [
f"characters/{character}.png",
f"characters/{character}.jpg",
f"characters/{character}.jpeg",
]:
path = Path(i)
if path.exists(): if path.exists():
image_html = f'<img src="file/{get_image_cache(path)}">' try:
break image_html = f'<img src="file/{get_image_cache(path)}">'
break
except:
continue
container_html += f'{image_html} <span class="character-name">{character}</span>' container_html += f'{image_html} <span class="character-name">{character}</span>'
container_html += "</div>" container_html += "</div>"

View File

@@ -0,0 +1,51 @@
from pathlib import Path
import gradio as gr
from modules import shared
from modules import ui as _ui
params = {
'template': '%input%'
}
def get_available_templates():
return ['None'] + sorted(set((k.stem for k in Path('extensions/prompt_template/templates').glob('*.txt'))), key=str.lower)
def load_template(fname):
if fname in ['None', '']:
return '%input%'
else:
with open(Path(f'extensions/prompt_template/templates/{fname}.txt'), 'r', encoding='utf-8') as f:
text = f.read()
if text[-1] == '\n':
text = text[:-1]
return text
def input_modifier(string):
"""
This function is applied to your text inputs before
they are fed into the model.
"""
return params['template'].replace('%input%', string)
def output_modifier(string):
return f'\n{string}'
def setup():
shared.args.verbose = True
def ui():
# Gradio elements
with gr.Row():
with gr.Column():
template = gr.Textbox(value=params['template'], info="%input% will be replaced with your user input.", label='Template')
with gr.Column():
with gr.Row():
template_menu = gr.Dropdown(choices=get_available_templates(), value='None', label='Available templates')
_ui.create_refresh_button(shared.gradio['model_menu'], lambda : None, lambda : {'choices': get_available_templates()}, 'refresh-button')
template_menu.change(load_template, template_menu, template)
template.change(lambda x: params.update({"template": x}), template, None)

View File

@@ -1,6 +1,5 @@
Below is an instruction that describes a task. Write a response that appropriately completes the request. Below is an instruction that describes a task. Write a response that appropriately completes the request.
### Instruction: ### Instruction:
Write a poem about the transformers Python library. %input%
Mention the word "large language models" in that poem.
### Response: ### Response:

View File

@@ -0,0 +1 @@
<|prompter|>%input%<|endoftext|><|assistant|>

View File

@@ -176,4 +176,4 @@ def ui():
force_btn.click(force_pic) force_btn.click(force_pic)
generate_now_btn.click(force_pic) generate_now_btn.click(force_pic)
generate_now_btn.click(chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream) generate_now_btn.click(eval('chat.cai_chatbot_wrapper'), shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream)

View File

@@ -2,11 +2,12 @@ import base64
from io import BytesIO from io import BytesIO
import gradio as gr import gradio as gr
import modules.chat as chat
import modules.shared as shared
import torch import torch
from PIL import Image
from transformers import BlipForConditionalGeneration, BlipProcessor from transformers import BlipForConditionalGeneration, BlipProcessor
from modules import chat, shared
# If 'state' is True, will hijack the next chat generation with # If 'state' is True, will hijack the next chat generation with
# custom input text given by 'value' in the format [text, visible_text] # custom input text given by 'value' in the format [text, visible_text]
input_hijack = { input_hijack = {
@@ -35,11 +36,13 @@ def generate_chat_picture(picture, name1, name2):
def ui(): def ui():
picture_select = gr.Image(label='Send a picture', type='pil') picture_select = gr.Image(label='Send a picture', type='pil')
function_call = 'chat.cai_chatbot_wrapper' if shared.args.cai_chat else 'chat.chatbot_wrapper'
# Prepare the hijack with custom inputs # Prepare the hijack with custom inputs
picture_select.upload(lambda picture, name1, name2: input_hijack.update({"state": True, "value": generate_chat_picture(picture, name1, name2)}), [picture_select, shared.gradio['name1'], shared.gradio['name2']], None) picture_select.upload(lambda picture, name1, name2: input_hijack.update({"state": True, "value": generate_chat_picture(picture, name1, name2)}), [picture_select, shared.gradio['name1'], shared.gradio['name2']], None)
# Call the generation function # Call the generation function
picture_select.upload(chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream) picture_select.upload(eval(function_call), shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream)
# Clear the picture from the upload field # Clear the picture from the upload field
picture_select.upload(lambda : None, [], [picture_select], show_progress=False) picture_select.upload(lambda : None, [], [picture_select], show_progress=False)

View File

@@ -52,7 +52,7 @@ def load_quantized(model_name):
if not shared.args.model_type: if not shared.args.model_type:
# Try to determine model type from model name # Try to determine model type from model name
name = model_name.lower() name = model_name.lower()
if any((k in name for k in ['llama', 'alpaca', 'vicuna'])): if any((k in name for k in ['llama', 'alpaca'])):
model_type = 'llama' model_type = 'llama'
elif any((k in name for k in ['opt-', 'galactica'])): elif any((k in name for k in ['opt-', 'galactica'])):
model_type = 'opt' model_type = 'opt'
@@ -65,18 +65,16 @@ def load_quantized(model_name):
else: else:
model_type = shared.args.model_type.lower() model_type = shared.args.model_type.lower()
if shared.args.pre_layer and model_type == 'llama': if model_type == 'llama' and shared.args.pre_layer:
load_quant = llama_inference_offload.load_quant load_quant = llama_inference_offload.load_quant
elif model_type in ('llama', 'opt', 'gptj'): elif model_type in ('llama', 'opt', 'gptj'):
if shared.args.pre_layer:
print("Warning: ignoring --pre_layer because it only works for llama model type.")
load_quant = _load_quant load_quant = _load_quant
else: else:
print("Unknown pre-quantized model type specified. Only 'llama', 'opt' and 'gptj' are supported") print("Unknown pre-quantized model type specified. Only 'llama', 'opt' and 'gptj' are supported")
exit() exit()
# Now we are going to try to locate the quantized model file. # Now we are going to try to locate the quantized model file.
path_to_model = Path(f'{shared.args.model_dir}/{model_name}') path_to_model = Path(f'models/{model_name}')
found_pts = list(path_to_model.glob("*.pt")) found_pts = list(path_to_model.glob("*.pt"))
found_safetensors = list(path_to_model.glob("*.safetensors")) found_safetensors = list(path_to_model.glob("*.safetensors"))
pt_path = None pt_path = None
@@ -97,8 +95,8 @@ def load_quantized(model_name):
else: else:
pt_model = f'{model_name}-{shared.args.wbits}bit' pt_model = f'{model_name}-{shared.args.wbits}bit'
# Try to find the .safetensors or .pt both in the model dir and in the subfolder # Try to find the .safetensors or .pt both in models/ and in the subfolder
for path in [Path(p+ext) for ext in ['.safetensors', '.pt'] for p in [f"{shared.args.model_dir}/{pt_model}", f"{path_to_model}/{pt_model}"]]: for path in [Path(p+ext) for ext in ['.safetensors', '.pt'] for p in [f"models/{pt_model}", f"{path_to_model}/{pt_model}"]]:
if path.exists(): if path.exists():
print(f"Found {path}") print(f"Found {path}")
pt_path = path pt_path = path
@@ -109,7 +107,7 @@ def load_quantized(model_name):
exit() exit()
# qwopqwop200's offload # qwopqwop200's offload
if model_type == 'llama' and shared.args.pre_layer: if shared.args.pre_layer:
model = load_quant(str(path_to_model), str(pt_path), shared.args.wbits, shared.args.groupsize, shared.args.pre_layer) model = load_quant(str(path_to_model), str(pt_path), shared.args.wbits, shared.args.groupsize, shared.args.pre_layer)
else: else:
threshold = False if model_type == 'gptj' else 128 threshold = False if model_type == 'gptj' else 128

View File

@@ -1,38 +0,0 @@
import json
import gradio as gr
from modules import shared
from modules.text_generation import generate_reply
def generate_reply_wrapper(string):
generate_params = {
'do_sample': True,
'temperature': 1,
'top_p': 1,
'typical_p': 1,
'repetition_penalty': 1,
'encoder_repetition_penalty': 1,
'top_k': 50,
'num_beams': 1,
'penalty_alpha': 0,
'min_length': 0,
'length_penalty': 1,
'no_repeat_ngram_size': 0,
'early_stopping': False,
}
params = json.loads(string)
for k in params[1]:
generate_params[k] = params[1][k]
for i in generate_reply(params[0], generate_params):
yield i
def create_apis():
t1 = gr.Textbox(visible=False)
t2 = gr.Textbox(visible=False)
dummy = gr.Button(visible=False)
input_params = [t1]
output_params = [t2] + [shared.gradio[k] for k in ['markdown', 'html']]
dummy.click(generate_reply_wrapper, input_params, output_params, api_name='textgen')

View File

@@ -12,56 +12,45 @@ from PIL import Image
import modules.extensions as extensions_module import modules.extensions as extensions_module
import modules.shared as shared import modules.shared as shared
from modules.extensions import apply_extensions from modules.extensions import apply_extensions
from modules.html_generator import (fix_newlines, chat_html_wrapper, from modules.html_generator import fix_newlines, generate_chat_html
make_thumbnail)
from modules.text_generation import (encode, generate_reply, from modules.text_generation import (encode, generate_reply,
get_max_prompt_length) get_max_prompt_length)
def generate_chat_prompt(user_input, max_new_tokens, name1, name2, context, chat_prompt_size, **kwargs): def generate_chat_output(history, name1, name2, character):
is_instruct = kwargs['is_instruct'] if 'is_instruct' in kwargs else False if shared.args.cai_chat:
end_of_turn = kwargs['end_of_turn'] if 'end_of_turn' in kwargs else '' return generate_chat_html(history, name1, name2, character)
impersonate = kwargs['impersonate'] if 'impersonate' in kwargs else False else:
also_return_rows = kwargs['also_return_rows'] if 'also_return_rows' in kwargs else False return history
def generate_chat_prompt(user_input, max_new_tokens, name1, name2, context, chat_prompt_size, impersonate=False, also_return_rows=False):
user_input = fix_newlines(user_input) user_input = fix_newlines(user_input)
rows = [f"{context.strip()}\n"] rows = [f"{context.strip()}\n"]
# Finding the maximum prompt size
if shared.soft_prompt: if shared.soft_prompt:
chat_prompt_size -= shared.soft_prompt_tensor.shape[1] chat_prompt_size -= shared.soft_prompt_tensor.shape[1]
max_length = min(get_max_prompt_length(max_new_tokens), chat_prompt_size) max_length = min(get_max_prompt_length(max_new_tokens), chat_prompt_size)
if is_instruct:
prefix1 = f"{name1}\n"
prefix2 = f"{name2}\n"
else:
prefix1 = f"{name1}: "
prefix2 = f"{name2}: "
i = len(shared.history['internal'])-1 i = len(shared.history['internal'])-1
while i >= 0 and len(encode(''.join(rows), max_new_tokens)[0]) < max_length: while i >= 0 and len(encode(''.join(rows), max_new_tokens)[0]) < max_length:
rows.insert(1, f"{prefix2}{shared.history['internal'][i][1].strip()}{end_of_turn}\n") rows.insert(1, f"{name2}: {shared.history['internal'][i][1].strip()}\n")
string = shared.history['internal'][i][0] prev_user_input = shared.history['internal'][i][0]
if string not in ['', '<|BEGIN-VISIBLE-CHAT|>']: if prev_user_input not in ['', '<|BEGIN-VISIBLE-CHAT|>']:
rows.insert(1, f"{prefix1}{string.strip()}{end_of_turn}\n") rows.insert(1, f"{name1}: {prev_user_input.strip()}\n")
i -= 1 i -= 1
if impersonate: if not impersonate:
rows.append(f"{prefix1.strip() if not is_instruct else prefix1}")
limit = 2
else:
# Adding the user message
if len(user_input) > 0: if len(user_input) > 0:
rows.append(f"{prefix1}{user_input}{end_of_turn}\n") rows.append(f"{name1}: {user_input}\n")
rows.append(apply_extensions(f"{name2}:", "bot_prefix"))
# Adding the Character prefix
rows.append(apply_extensions(f"{prefix2.strip() if not is_instruct else prefix2}", "bot_prefix"))
limit = 3 limit = 3
else:
rows.append(f"{name1}:")
limit = 2
while len(rows) > limit and len(encode(''.join(rows), max_new_tokens)[0]) >= max_length: while len(rows) > limit and len(encode(''.join(rows), max_new_tokens)[0]) >= max_length:
rows.pop(1) rows.pop(1)
prompt = ''.join(rows) prompt = ''.join(rows)
if also_return_rows: if also_return_rows:
@@ -96,9 +85,9 @@ def extract_message_from_reply(reply, name1, name2, stop_at_newline):
reply = fix_newlines(reply) reply = fix_newlines(reply)
return reply, next_character_found return reply, next_character_found
def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn, regenerate=False): def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, name1, name2, context, stop_at_newline, chat_prompt_size, chat_generation_attempts=1, regenerate=False):
just_started = True just_started = True
eos_token = '\n' if generate_state['stop_at_newline'] else None eos_token = '\n' if stop_at_newline else None
name1_original = name1 name1_original = name1
if 'pygmalion' in shared.model_name.lower(): if 'pygmalion' in shared.model_name.lower():
name1 = "You" name1 = "You"
@@ -115,13 +104,14 @@ def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_tu
if visible_text is None: if visible_text is None:
visible_text = text visible_text = text
if shared.args.chat:
visible_text = visible_text.replace('\n', '<br>')
text = apply_extensions(text, "input") text = apply_extensions(text, "input")
kwargs = {'end_of_turn': end_of_turn, 'is_instruct': mode == 'instruct'}
if custom_generate_chat_prompt is None: if custom_generate_chat_prompt is None:
prompt = generate_chat_prompt(text, generate_state['max_new_tokens'], name1, name2, context, generate_state['chat_prompt_size'], **kwargs) prompt = generate_chat_prompt(text, max_new_tokens, name1, name2, context, chat_prompt_size)
else: else:
prompt = custom_generate_chat_prompt(text, generate_state['max_new_tokens'], name1, name2, context, generate_state['chat_prompt_size'], **kwargs) prompt = custom_generate_chat_prompt(text, max_new_tokens, name1, name2, context, chat_prompt_size)
# Yield *Is typing...* # Yield *Is typing...*
if not regenerate: if not regenerate:
@@ -129,15 +119,17 @@ def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_tu
# Generate # Generate
cumulative_reply = '' cumulative_reply = ''
for i in range(generate_state['chat_generation_attempts']): for i in range(chat_generation_attempts):
reply = None reply = None
for reply in generate_reply(f"{prompt}{' ' if len(cumulative_reply) > 0 else ''}{cumulative_reply}", generate_state, eos_token=eos_token, stopping_strings=[f"\n{name1}:", f"\n{name2}:"]): for reply in generate_reply(f"{prompt}{' ' if len(cumulative_reply) > 0 else ''}{cumulative_reply}", max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, eos_token=eos_token, stopping_strings=[f"\n{name1}:", f"\n{name2}:"]):
reply = cumulative_reply + reply reply = cumulative_reply + reply
# Extracting the reply # Extracting the reply
reply, next_character_found = extract_message_from_reply(reply, name1, name2, generate_state['stop_at_newline']) reply, next_character_found = extract_message_from_reply(reply, name1, name2, stop_at_newline)
visible_reply = re.sub("(<USER>|<user>|{{user}})", name1_original, reply) visible_reply = re.sub("(<USER>|<user>|{{user}})", name1_original, reply)
visible_reply = apply_extensions(visible_reply, "output") visible_reply = apply_extensions(visible_reply, "output")
if shared.args.chat:
visible_reply = visible_reply.replace('\n', '<br>')
# We need this global variable to handle the Stop event, # We need this global variable to handle the Stop event,
# otherwise gradio gets confused # otherwise gradio gets confused
@@ -160,23 +152,23 @@ def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_tu
yield shared.history['visible'] yield shared.history['visible']
def impersonate_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn): def impersonate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, name1, name2, context, stop_at_newline, chat_prompt_size, chat_generation_attempts=1):
eos_token = '\n' if generate_state['stop_at_newline'] else None eos_token = '\n' if stop_at_newline else None
if 'pygmalion' in shared.model_name.lower(): if 'pygmalion' in shared.model_name.lower():
name1 = "You" name1 = "You"
prompt = generate_chat_prompt(text, generate_state['max_new_tokens'], name1, name2, context, generate_state['chat_prompt_size'], impersonate=True, end_of_turn=end_of_turn) prompt = generate_chat_prompt(text, max_new_tokens, name1, name2, context, chat_prompt_size, impersonate=True)
# Yield *Is typing...* # Yield *Is typing...*
yield shared.processing_message yield shared.processing_message
cumulative_reply = '' cumulative_reply = ''
for i in range(generate_state['chat_generation_attempts']): for i in range(chat_generation_attempts):
reply = None reply = None
for reply in generate_reply(f"{prompt}{' ' if len(cumulative_reply) > 0 else ''}{cumulative_reply}", generate_state, eos_token=eos_token, stopping_strings=[f"\n{name1}:", f"\n{name2}:"]): for reply in generate_reply(f"{prompt}{' ' if len(cumulative_reply) > 0 else ''}{cumulative_reply}", max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, eos_token=eos_token, stopping_strings=[f"\n{name1}:", f"\n{name2}:"]):
reply = cumulative_reply + reply reply = cumulative_reply + reply
reply, next_character_found = extract_message_from_reply(reply, name1, name2, generate_state['stop_at_newline']) reply, next_character_found = extract_message_from_reply(reply, name1, name2, stop_at_newline)
yield reply yield reply
if next_character_found: if next_character_found:
break break
@@ -186,30 +178,36 @@ def impersonate_wrapper(text, generate_state, name1, name2, context, mode, end_o
yield reply yield reply
def cai_chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn): def cai_chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, name1, name2, context, stop_at_newline, chat_prompt_size, chat_generation_attempts=1):
for history in chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn, regenerate=False): for history in chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, name1, name2, context, stop_at_newline, chat_prompt_size, chat_generation_attempts):
yield chat_html_wrapper(history, name1, name2, mode) yield generate_chat_html(history, name1, name2, shared.character)
def regenerate_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn): def regenerate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, name1, name2, context, stop_at_newline, chat_prompt_size, chat_generation_attempts=1):
if (shared.character != 'None' and len(shared.history['visible']) == 1) or len(shared.history['internal']) == 0: if (shared.character != 'None' and len(shared.history['visible']) == 1) or len(shared.history['internal']) == 0:
yield chat_html_wrapper(shared.history['visible'], name1, name2, mode) yield generate_chat_output(shared.history['visible'], name1, name2, shared.character)
else: else:
last_visible = shared.history['visible'].pop() last_visible = shared.history['visible'].pop()
last_internal = shared.history['internal'].pop() last_internal = shared.history['internal'].pop()
# Yield '*Is typing...*' # Yield '*Is typing...*'
yield chat_html_wrapper(shared.history['visible']+[[last_visible[0], shared.processing_message]], name1, name2, mode) yield generate_chat_output(shared.history['visible']+[[last_visible[0], shared.processing_message]], name1, name2, shared.character)
for history in chatbot_wrapper(last_internal[0], generate_state, name1, name2, context, mode, end_of_turn, regenerate=True): for history in chatbot_wrapper(last_internal[0], max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, name1, name2, context, stop_at_newline, chat_prompt_size, chat_generation_attempts, regenerate=True):
shared.history['visible'][-1] = [last_visible[0], history[-1][1]] if shared.args.cai_chat:
yield chat_html_wrapper(shared.history['visible'], name1, name2, mode) shared.history['visible'][-1] = [last_visible[0], history[-1][1]]
else:
shared.history['visible'][-1] = (last_visible[0], history[-1][1])
yield generate_chat_output(shared.history['visible'], name1, name2, shared.character)
def remove_last_message(name1, name2, mode): def remove_last_message(name1, name2):
if len(shared.history['visible']) > 0 and shared.history['internal'][-1][0] != '<|BEGIN-VISIBLE-CHAT|>': if len(shared.history['visible']) > 0 and shared.history['internal'][-1][0] != '<|BEGIN-VISIBLE-CHAT|>':
last = shared.history['visible'].pop() last = shared.history['visible'].pop()
shared.history['internal'].pop() shared.history['internal'].pop()
else: else:
last = ['', ''] last = ['', '']
return chat_html_wrapper(shared.history['visible'], name1, name2, mode), last[0] if shared.args.cai_chat:
return generate_chat_html(shared.history['visible'], name1, name2, shared.character), last[0]
else:
return shared.history['visible'], last[0]
def send_last_reply_to_input(): def send_last_reply_to_input():
if len(shared.history['internal']) > 0: if len(shared.history['internal']) > 0:
@@ -217,17 +215,20 @@ def send_last_reply_to_input():
else: else:
return '' return ''
def replace_last_reply(text, name1, name2, mode): def replace_last_reply(text, name1, name2):
if len(shared.history['visible']) > 0: if len(shared.history['visible']) > 0:
shared.history['visible'][-1][1] = text if shared.args.cai_chat:
shared.history['visible'][-1][1] = text
else:
shared.history['visible'][-1] = (shared.history['visible'][-1][0], text)
shared.history['internal'][-1][1] = apply_extensions(text, "input") shared.history['internal'][-1][1] = apply_extensions(text, "input")
return chat_html_wrapper(shared.history['visible'], name1, name2, mode) return generate_chat_output(shared.history['visible'], name1, name2, shared.character)
def clear_html(): def clear_html():
return chat_html_wrapper([], "", "") return generate_chat_html([], "", "", shared.character)
def clear_chat_log(name1, name2, greeting, mode): def clear_chat_log(name1, name2, greeting):
shared.history['visible'] = [] shared.history['visible'] = []
shared.history['internal'] = [] shared.history['internal'] = []
@@ -235,12 +236,12 @@ def clear_chat_log(name1, name2, greeting, mode):
shared.history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', greeting]] shared.history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', greeting]]
shared.history['visible'] += [['', apply_extensions(greeting, "output")]] shared.history['visible'] += [['', apply_extensions(greeting, "output")]]
return chat_html_wrapper(shared.history['visible'], name1, name2, mode) return generate_chat_output(shared.history['visible'], name1, name2, shared.character)
def redraw_html(name1, name2, mode): def redraw_html(name1, name2):
return chat_html_wrapper(shared.history['visible'], name1, name2, mode) return generate_chat_html(shared.history['visible'], name1, name2, shared.character)
def tokenize_dialogue(dialogue, name1, name2, mode): def tokenize_dialogue(dialogue, name1, name2):
history = [] history = []
dialogue = re.sub('<START>', '', dialogue) dialogue = re.sub('<START>', '', dialogue)
@@ -325,35 +326,15 @@ def build_pygmalion_style_context(data):
context = f"{context.strip()}\n<START>\n" context = f"{context.strip()}\n<START>\n"
return context return context
def generate_pfp_cache(character): def load_character(character, name1, name2):
cache_folder = Path("cache")
if not cache_folder.exists():
cache_folder.mkdir()
for path in [Path(f"characters/{character}.{extension}") for extension in ['png', 'jpg', 'jpeg']]:
if path.exists():
img = make_thumbnail(Image.open(path))
img.save(Path('cache/pfp_character.png'), format='PNG')
return img
return None
def load_character(character, name1, name2, mode):
shared.character = character shared.character = character
shared.history['internal'] = [] shared.history['internal'] = []
shared.history['visible'] = [] shared.history['visible'] = []
context = greeting = end_of_turn = "" greeting = ""
greeting_field = 'greeting'
picture = None
# Deleting the profile picture cache, if any
if Path("cache/pfp_character.png").exists():
Path("cache/pfp_character.png").unlink()
if character != 'None': if character != 'None':
folder = 'characters' if not mode == 'instruct' else 'characters/instruction-following'
picture = generate_pfp_cache(character)
for extension in ["yml", "yaml", "json"]: for extension in ["yml", "yaml", "json"]:
filepath = Path(f'{folder}/{character}.{extension}') filepath = Path(f'characters/{character}.{extension}')
if filepath.exists(): if filepath.exists():
break break
file_contents = open(filepath, 'r', encoding='utf-8').read() file_contents = open(filepath, 'r', encoding='utf-8').read()
@@ -369,21 +350,19 @@ def load_character(character, name1, name2, mode):
if 'context' in data: if 'context' in data:
context = f"{data['context'].strip()}\n\n" context = f"{data['context'].strip()}\n\n"
elif "char_persona" in data: greeting_field = 'greeting'
else:
context = build_pygmalion_style_context(data) context = build_pygmalion_style_context(data)
greeting_field = 'char_greeting' greeting_field = 'char_greeting'
if 'example_dialogue' in data: if 'example_dialogue' in data and data['example_dialogue'] != '':
context += f"{data['example_dialogue'].strip()}\n" context += f"{data['example_dialogue'].strip()}\n"
if greeting_field in data: if greeting_field in data and len(data[greeting_field].strip()) > 0:
greeting = data[greeting_field] greeting = data[greeting_field]
if 'end_of_turn' in data:
end_of_turn = data['end_of_turn']
else: else:
context = shared.settings['context'] context = shared.settings['context']
name2 = shared.settings['name2'] name2 = shared.settings['name2']
greeting = shared.settings['greeting'] greeting = shared.settings['greeting']
end_of_turn = shared.settings['end_of_turn']
if Path(f'logs/{shared.character}_persistent.json').exists(): if Path(f'logs/{shared.character}_persistent.json').exists():
load_history(open(Path(f'logs/{shared.character}_persistent.json'), 'rb').read(), name1, name2) load_history(open(Path(f'logs/{shared.character}_persistent.json'), 'rb').read(), name1, name2)
@@ -391,10 +370,13 @@ def load_character(character, name1, name2, mode):
shared.history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', greeting]] shared.history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', greeting]]
shared.history['visible'] += [['', apply_extensions(greeting, "output")]] shared.history['visible'] += [['', apply_extensions(greeting, "output")]]
return name1, name2, picture, greeting, context, end_of_turn, chat_html_wrapper(shared.history['visible'], name1, name2, mode, reset_cache=True) if shared.args.cai_chat:
return name1, name2, greeting, context, generate_chat_html(shared.history['visible'], name1, name2, shared.character)
else:
return name1, name2, greeting, context, shared.history['visible']
def load_default_history(name1, name2): def load_default_history(name1, name2):
load_character("None", name1, name2, "chat") load_character("None", name1, name2)
def upload_character(json_file, img, tavern=False): def upload_character(json_file, img, tavern=False):
json_file = json_file if type(json_file) == str else json_file.decode('utf-8') json_file = json_file if type(json_file) == str else json_file.decode('utf-8')
@@ -422,17 +404,7 @@ def upload_tavern_character(img, name1, name2):
_json = {"char_name": _json['name'], "char_persona": _json['description'], "char_greeting": _json["first_mes"], "example_dialogue": _json['mes_example'], "world_scenario": _json['scenario']} _json = {"char_name": _json['name'], "char_persona": _json['description'], "char_greeting": _json["first_mes"], "example_dialogue": _json['mes_example'], "world_scenario": _json['scenario']}
return upload_character(json.dumps(_json), img, tavern=True) return upload_character(json.dumps(_json), img, tavern=True)
def upload_your_profile_picture(img, name1, name2, mode): def upload_your_profile_picture(img):
cache_folder = Path("cache") img = Image.open(io.BytesIO(img))
if not cache_folder.exists(): img.save(Path('img_me.png'))
cache_folder.mkdir() print('Profile picture saved to "img_me.png"')
if img == None:
if Path("cache/pfp_me.png").exists():
Path("cache/pfp_me.png").unlink()
else:
img = make_thumbnail(img)
img.save(Path('cache/pfp_me.png'))
print('Profile picture saved to "cache/pfp_me.png"')
return chat_html_wrapper(shared.history['visible'], name1, name2, mode, reset_cache=True)

View File

@@ -6,11 +6,10 @@ This is a library for formatting text outputs as nice HTML.
import os import os
import re import re
import time
from pathlib import Path from pathlib import Path
import markdown import markdown
from PIL import Image, ImageOps from PIL import Image
# This is to store the paths to the thumbnails of the profile pictures # This is to store the paths to the thumbnails of the profile pictures
image_cache = {} image_cache = {}
@@ -21,8 +20,6 @@ with open(Path(__file__).resolve().parent / '../css/html_4chan_style.css', 'r')
_4chan_css = css_f.read() _4chan_css = css_f.read()
with open(Path(__file__).resolve().parent / '../css/html_cai_style.css', 'r') as f: with open(Path(__file__).resolve().parent / '../css/html_cai_style.css', 'r') as f:
cai_css = f.read() cai_css = f.read()
with open(Path(__file__).resolve().parent / '../css/html_instruct_style.css', 'r') as f:
instruct_css = f.read()
def fix_newlines(string): def fix_newlines(string):
string = string.replace('\n', '\n\n') string = string.replace('\n', '\n\n')
@@ -98,13 +95,6 @@ def generate_4chan_html(f):
return output return output
def make_thumbnail(image):
image = image.resize((350, round(image.size[1]/image.size[0]*350)), Image.Resampling.LANCZOS)
if image.size[1] > 470:
image = ImageOps.fit(image, (350, 470), Image.ANTIALIAS)
return image
def get_image_cache(path): def get_image_cache(path):
cache_folder = Path("cache") cache_folder = Path("cache")
if not cache_folder.exists(): if not cache_folder.exists():
@@ -112,52 +102,26 @@ def get_image_cache(path):
mtime = os.stat(path).st_mtime mtime = os.stat(path).st_mtime
if (path in image_cache and mtime != image_cache[path][0]) or (path not in image_cache): if (path in image_cache and mtime != image_cache[path][0]) or (path not in image_cache):
img = make_thumbnail(Image.open(path)) img = Image.open(path)
img.thumbnail((200, 200))
output_file = Path(f'cache/{path.name}_cache.png') output_file = Path(f'cache/{path.name}_cache.png')
img.convert('RGB').save(output_file, format='PNG') img.convert('RGB').save(output_file, format='PNG')
image_cache[path] = [mtime, output_file.as_posix()] image_cache[path] = [mtime, output_file.as_posix()]
return image_cache[path][1] return image_cache[path][1]
def generate_instruct_html(history): def load_html_image(paths):
output = f'<style>{instruct_css}</style><div class="chat" id="chat">' for str_path in paths:
for i,_row in enumerate(history[::-1]): path = Path(str_path)
row = [convert_to_markdown(entry) for entry in _row] if path.exists():
return f'<img src="file/{get_image_cache(path)}">'
return ''
output += f""" def generate_chat_html(history, name1, name2, character):
<div class="assistant-message">
<div class="text">
<div class="message-body">
{row[1]}
</div>
</div>
</div>
"""
if len(row[0]) == 0: # don't display empty user messages
continue
output += f"""
<div class="user-message">
<div class="text">
<div class="message-body">
{row[0]}
</div>
</div>
</div>
"""
output += "</div>"
return output
def generate_cai_chat_html(history, name1, name2, reset_cache=False):
output = f'<style>{cai_css}</style><div class="chat" id="chat">' output = f'<style>{cai_css}</style><div class="chat" id="chat">'
# The time.time() is to prevent the brower from caching the image img_bot = load_html_image([f"characters/{character}.{ext}" for ext in ['png', 'jpg', 'jpeg']] + ["img_bot.png","img_bot.jpg","img_bot.jpeg"])
suffix = f"?{time.time()}" if reset_cache else f"?{name2}" img_me = load_html_image(["img_me.png", "img_me.jpg", "img_me.jpeg"])
img_bot = f'<img src="file/cache/pfp_character.png{suffix}">' if Path("cache/pfp_character.png").exists() else ''
img_me = f'<img src="file/cache/pfp_me.png{suffix}">' if Path("cache/pfp_me.png").exists() else ''
for i,_row in enumerate(history[::-1]): for i,_row in enumerate(history[::-1]):
row = [convert_to_markdown(entry) for entry in _row] row = [convert_to_markdown(entry) for entry in _row]
@@ -199,16 +163,3 @@ def generate_cai_chat_html(history, name1, name2, reset_cache=False):
output += "</div>" output += "</div>"
return output return output
def generate_chat_html(history, name1, name2):
return generate_cai_chat_html(history, name1, name2)
def chat_html_wrapper(history, name1, name2, mode, reset_cache=False):
if mode == "cai-chat":
return generate_cai_chat_html(history, name1, name2, reset_cache)
elif mode == "chat":
return generate_chat_html(history, name1, name2)
elif mode == "instruct":
return generate_instruct_html(history)
else:
return ''

View File

@@ -42,7 +42,7 @@ def load_model(model_name):
t0 = time.time() t0 = time.time()
shared.is_RWKV = 'rwkv-' in model_name.lower() shared.is_RWKV = 'rwkv-' in model_name.lower()
shared.is_llamacpp = len(list(Path(f'{shared.args.model_dir}/{model_name}').glob('ggml*.bin'))) > 0 shared.is_llamacpp = len(list(Path(f'models/{model_name}').glob('ggml*.bin'))) > 0
# Default settings # Default settings
if not any([shared.args.cpu, shared.args.load_in_8bit, shared.args.wbits, shared.args.auto_devices, shared.args.disk, shared.args.gpu_memory is not None, shared.args.cpu_memory is not None, shared.args.deepspeed, shared.args.flexgen, shared.is_RWKV, shared.is_llamacpp]): if not any([shared.args.cpu, shared.args.load_in_8bit, shared.args.wbits, shared.args.auto_devices, shared.args.disk, shared.args.gpu_memory is not None, shared.args.cpu_memory is not None, shared.args.deepspeed, shared.args.flexgen, shared.is_RWKV, shared.is_llamacpp]):
@@ -105,7 +105,7 @@ def load_model(model_name):
elif shared.is_llamacpp: elif shared.is_llamacpp:
from modules.llamacpp_model import LlamaCppModel from modules.llamacpp_model import LlamaCppModel
model_file = list(Path(f'{shared.args.model_dir}/{model_name}').glob('ggml*.bin'))[0] model_file = list(Path(f'models/{model_name}').glob('ggml*.bin'))[0]
print(f"llama.cpp weights detected: {model_file}\n") print(f"llama.cpp weights detected: {model_file}\n")
model, tokenizer = LlamaCppModel.from_pretrained(model_file) model, tokenizer = LlamaCppModel.from_pretrained(model_file)

View File

@@ -33,7 +33,6 @@ settings = {
'name2': 'Assistant', 'name2': 'Assistant',
'context': 'This is a conversation with your Assistant. The Assistant is very helpful and is eager to chat with you and answer your questions.', 'context': 'This is a conversation with your Assistant. The Assistant is very helpful and is eager to chat with you and answer your questions.',
'greeting': 'Hello there!', 'greeting': 'Hello there!',
'end_of_turn': '',
'stop_at_newline': False, 'stop_at_newline': False,
'chat_prompt_size': 2048, 'chat_prompt_size': 2048,
'chat_prompt_size_min': 0, 'chat_prompt_size_min': 0,
@@ -45,7 +44,6 @@ settings = {
'chat_default_extensions': ["gallery"], 'chat_default_extensions': ["gallery"],
'presets': { 'presets': {
'default': 'NovelAI-Sphinx Moth', 'default': 'NovelAI-Sphinx Moth',
'.*(alpaca|llama)': "LLaMA-Precise",
'.*pygmalion': 'NovelAI-Storywriter', '.*pygmalion': 'NovelAI-Storywriter',
'.*RWKV': 'Naive', '.*RWKV': 'Naive',
}, },
@@ -75,8 +73,8 @@ parser = argparse.ArgumentParser(formatter_class=lambda prog: argparse.HelpForma
# Basic settings # Basic settings
parser.add_argument('--notebook', action='store_true', help='Launch the web UI in notebook mode, where the output is written to the same text box as the input.') parser.add_argument('--notebook', action='store_true', help='Launch the web UI in notebook mode, where the output is written to the same text box as the input.')
parser.add_argument('--chat', action='store_true', help='Launch the web UI in chat mode with a style similar to the Character.AI website.') parser.add_argument('--chat', action='store_true', help='Launch the web UI in chat mode.')
parser.add_argument('--cai-chat', action='store_true', help='DEPRECATED: use --chat instead.') parser.add_argument('--cai-chat', action='store_true', help='Launch the web UI in chat mode with a style similar to Character.AI\'s. If the file img_bot.png or img_bot.jpg exists in the same folder as server.py, this image will be used as the bot\'s profile picture. Similarly, img_me.png or img_me.jpg will be used as your profile picture.')
parser.add_argument('--model', type=str, help='Name of the model to load by default.') parser.add_argument('--model', type=str, help='Name of the model to load by default.')
parser.add_argument('--lora', type=str, help='Name of the LoRA to apply to the model by default.') parser.add_argument('--lora', type=str, help='Name of the LoRA to apply to the model by default.')
parser.add_argument("--model-dir", type=str, default='models/', help="Path to directory with all the models") parser.add_argument("--model-dir", type=str, default='models/', help="Path to directory with all the models")
@@ -133,17 +131,12 @@ parser.add_argument("--gradio-auth-path", type=str, help='Set the gradio authent
args = parser.parse_args() args = parser.parse_args()
# Deprecation warnings for parameters that have been renamed # Provisional, this will be deleted later
deprecated_dict = {'gptq_bits': ['wbits', 0], 'gptq_model_type': ['model_type', None], 'gptq_pre_layer': ['prelayer', 0]} deprecated_dict = {'gptq_bits': ['wbits', 0], 'gptq_model_type': ['model_type', None], 'gptq_pre_layer': ['prelayer', 0]}
for k in deprecated_dict: for k in deprecated_dict:
if eval(f"args.{k}") != deprecated_dict[k][1]: if eval(f"args.{k}") != deprecated_dict[k][1]:
print(f"Warning: --{k} is deprecated and will be removed. Use --{deprecated_dict[k][0]} instead.") print(f"Warning: --{k} is deprecated and will be removed. Use --{deprecated_dict[k][0]} instead.")
exec(f"args.{deprecated_dict[k][0]} = args.{k}") exec(f"args.{deprecated_dict[k][0]} = args.{k}")
# Deprecation warnings for parameters that have been removed
if args.cai_chat:
print("Warning: --cai-chat is deprecated. Use --chat instead.")
args.chat = True
def is_chat(): def is_chat():
return args.chat return any((args.chat, args.cai_chat))

View File

@@ -102,11 +102,10 @@ def set_manual_seed(seed):
def stop_everything_event(): def stop_everything_event():
shared.stop_everything = True shared.stop_everything = True
def generate_reply(question, generate_state, eos_token=None, stopping_strings=[]): def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, eos_token=None, stopping_strings=[]):
clear_torch_cache() clear_torch_cache()
set_manual_seed(generate_state['seed']) set_manual_seed(seed)
shared.stop_everything = False shared.stop_everything = False
generate_params = {}
t0 = time.time() t0 = time.time()
original_question = question original_question = question
@@ -118,12 +117,9 @@ def generate_reply(question, generate_state, eos_token=None, stopping_strings=[]
# These models are not part of Hugging Face, so we handle them # These models are not part of Hugging Face, so we handle them
# separately and terminate the function call earlier # separately and terminate the function call earlier
if any((shared.is_RWKV, shared.is_llamacpp)): if any((shared.is_RWKV, shared.is_llamacpp)):
for k in ['temperature', 'top_p', 'top_k', 'repetition_penalty']:
generate_params[k] = generate_state[k]
generate_params["token_count"] = generate_state["max_new_tokens"]
try: try:
if shared.args.no_stream: if shared.args.no_stream:
reply = shared.model.generate(context=question, **generate_params) reply = shared.model.generate(context=question, token_count=max_new_tokens, temperature=temperature, top_p=top_p, top_k=top_k, repetition_penalty=repetition_penalty)
output = original_question+reply output = original_question+reply
if not shared.is_chat(): if not shared.is_chat():
reply = original_question + apply_extensions(reply, "output") reply = original_question + apply_extensions(reply, "output")
@@ -134,7 +130,7 @@ def generate_reply(question, generate_state, eos_token=None, stopping_strings=[]
# RWKV has proper streaming, which is very nice. # RWKV has proper streaming, which is very nice.
# No need to generate 8 tokens at a time. # No need to generate 8 tokens at a time.
for reply in shared.model.generate_with_streaming(context=question, **generate_params): for reply in shared.model.generate_with_streaming(context=question, token_count=max_new_tokens, temperature=temperature, top_p=top_p, top_k=top_k, repetition_penalty=repetition_penalty):
output = original_question+reply output = original_question+reply
if not shared.is_chat(): if not shared.is_chat():
reply = original_question + apply_extensions(reply, "output") reply = original_question + apply_extensions(reply, "output")
@@ -149,7 +145,7 @@ def generate_reply(question, generate_state, eos_token=None, stopping_strings=[]
print(f"Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens})") print(f"Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens})")
return return
input_ids = encode(question, generate_state['max_new_tokens']) input_ids = encode(question, max_new_tokens)
original_input_ids = input_ids original_input_ids = input_ids
output = input_ids[0] output = input_ids[0]
@@ -162,21 +158,33 @@ def generate_reply(question, generate_state, eos_token=None, stopping_strings=[]
t = [encode(string, 0, add_special_tokens=False) for string in stopping_strings] t = [encode(string, 0, add_special_tokens=False) for string in stopping_strings]
stopping_criteria_list.append(_SentinelTokenStoppingCriteria(sentinel_token_ids=t, starting_idx=len(input_ids[0]))) stopping_criteria_list.append(_SentinelTokenStoppingCriteria(sentinel_token_ids=t, starting_idx=len(input_ids[0])))
generate_params["max_new_tokens"] = generate_state['max_new_tokens'] generate_params = {}
if not shared.args.flexgen: if not shared.args.flexgen:
for k in ["do_sample", "temperature", "top_p", "typical_p", "repetition_penalty", "encoder_repetition_penalty", "top_k", "min_length", "no_repeat_ngram_size", "num_beams", "penalty_alpha", "length_penalty", "early_stopping"]: generate_params.update({
generate_params[k] = generate_state[k] "max_new_tokens": max_new_tokens,
generate_params["eos_token_id"] = eos_token_ids "eos_token_id": eos_token_ids,
generate_params["stopping_criteria"] = stopping_criteria_list "stopping_criteria": stopping_criteria_list,
if shared.args.no_stream: "do_sample": do_sample,
generate_params["min_length"] = 0 "temperature": temperature,
"top_p": top_p,
"typical_p": typical_p,
"repetition_penalty": repetition_penalty,
"encoder_repetition_penalty": encoder_repetition_penalty,
"top_k": top_k,
"min_length": min_length if shared.args.no_stream else 0,
"no_repeat_ngram_size": no_repeat_ngram_size,
"num_beams": num_beams,
"penalty_alpha": penalty_alpha,
"length_penalty": length_penalty,
"early_stopping": early_stopping,
})
else: else:
for k in ["do_sample", "temperature"]: generate_params.update({
generate_params[k] = generate_state[k] "max_new_tokens": max_new_tokens if shared.args.no_stream else 8,
generate_params["stop"] = generate_state["eos_token_ids"][-1] "do_sample": do_sample,
if not shared.args.no_stream: "temperature": temperature,
generate_params["max_new_tokens"] = 8 "stop": eos_token_ids[-1],
})
if shared.args.no_cache: if shared.args.no_cache:
generate_params.update({"use_cache": False}) generate_params.update({"use_cache": False})
if shared.args.deepspeed: if shared.args.deepspeed:
@@ -236,7 +244,7 @@ def generate_reply(question, generate_state, eos_token=None, stopping_strings=[]
# Stream the output naively for FlexGen since it doesn't support 'stopping_criteria' # Stream the output naively for FlexGen since it doesn't support 'stopping_criteria'
else: else:
for i in range(generate_state['max_new_tokens']//8+1): for i in range(max_new_tokens//8+1):
clear_torch_cache() clear_torch_cache()
with torch.no_grad(): with torch.no_grad():
output = shared.model.generate(**generate_params)[0] output = shared.model.generate(**generate_params)[0]

View File

@@ -1,6 +0,0 @@
do_sample=True
top_p=0.1
top_k=40
temperature=0.7
repetition_penalty=1.18
typical_p=1.0

View File

@@ -1 +0,0 @@
<|prompter|>Write a story about future of AI development<|endoftext|><|assistant|>

View File

@@ -13,4 +13,4 @@ safetensors==0.3.0
sentencepiece sentencepiece
pyyaml pyyaml
tqdm tqdm
git+https://github.com/huggingface/transformers@9eae4aa57650c1dbe1becd4e0979f6ad1e572ac0 git+https://github.com/huggingface/transformers

121
server.py
View File

@@ -1,7 +1,3 @@
import os
os.environ['GRADIO_ANALYTICS_ENABLED'] = 'False'
import io import io
import json import json
import re import re
@@ -12,11 +8,10 @@ from datetime import datetime
from pathlib import Path from pathlib import Path
import gradio as gr import gradio as gr
from PIL import Image
import modules.extensions as extensions_module import modules.extensions as extensions_module
from modules import chat, shared, training, ui, api from modules import chat, shared, training, ui
from modules.html_generator import chat_html_wrapper from modules.html_generator import generate_chat_html
from modules.LoRA import add_lora_to_model from modules.LoRA import add_lora_to_model
from modules.models import load_model, load_soft_prompt from modules.models import load_model, load_soft_prompt
from modules.text_generation import (clear_torch_cache, generate_reply, from modules.text_generation import (clear_torch_cache, generate_reply,
@@ -52,10 +47,6 @@ def get_available_prompts():
def get_available_characters(): def get_available_characters():
paths = (x for x in Path('characters').iterdir() if x.suffix in ('.json', '.yaml', '.yml')) paths = (x for x in Path('characters').iterdir() if x.suffix in ('.json', '.yaml', '.yml'))
return ['None'] + sorted(set((k.stem for k in paths if k.stem != "instruction-following")), key=str.lower)
def get_available_instruction_templates():
paths = (x for x in Path('characters/instruction-following').iterdir() if x.suffix in ('.json', '.yaml', '.yml'))
return ['None'] + sorted(set((k.stem for k in paths)), key=str.lower) return ['None'] + sorted(set((k.stem for k in paths)), key=str.lower)
def get_available_extensions(): def get_available_extensions():
@@ -85,7 +76,7 @@ def load_lora_wrapper(selected_lora):
add_lora_to_model(selected_lora) add_lora_to_model(selected_lora)
return selected_lora return selected_lora
def load_preset_values(preset_menu, state, return_dict=False): def load_preset_values(preset_menu, return_dict=False):
generate_params = { generate_params = {
'do_sample': True, 'do_sample': True,
'temperature': 1, 'temperature': 1,
@@ -107,13 +98,13 @@ def load_preset_values(preset_menu, state, return_dict=False):
i = i.rstrip(',').strip().split('=') i = i.rstrip(',').strip().split('=')
if len(i) == 2 and i[0].strip() != 'tokens': if len(i) == 2 and i[0].strip() != 'tokens':
generate_params[i[0].strip()] = eval(i[1].strip()) generate_params[i[0].strip()] = eval(i[1].strip())
generate_params['temperature'] = min(1.99, generate_params['temperature']) generate_params['temperature'] = min(1.99, generate_params['temperature'])
if return_dict: if return_dict:
return generate_params return generate_params
else: else:
state.update(generate_params) return generate_params['do_sample'], generate_params['temperature'], generate_params['top_p'], generate_params['typical_p'], generate_params['repetition_penalty'], generate_params['encoder_repetition_penalty'], generate_params['top_k'], generate_params['min_length'], generate_params['no_repeat_ngram_size'], generate_params['num_beams'], generate_params['penalty_alpha'], generate_params['length_penalty'], generate_params['early_stopping']
return state, *[generate_params[k] for k in ['do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']]
def upload_soft_prompt(file): def upload_soft_prompt(file):
with zipfile.ZipFile(io.BytesIO(file)) as zf: with zipfile.ZipFile(io.BytesIO(file)) as zf:
@@ -139,7 +130,7 @@ def create_model_and_preset_menus():
ui.create_refresh_button(shared.gradio['preset_menu'], lambda : None, lambda : {'choices': get_available_presets()}, 'refresh-button') ui.create_refresh_button(shared.gradio['preset_menu'], lambda : None, lambda : {'choices': get_available_presets()}, 'refresh-button')
def save_prompt(text): def save_prompt(text):
fname = f"{datetime.now().strftime('%Y-%m-%d-%H%M%S')}.txt" fname = f"{datetime.now().strftime('%Y-%m-%d-%H:%M:%S')}.txt"
with open(Path(f'prompts/{fname}'), 'w', encoding='utf-8') as f: with open(Path(f'prompts/{fname}'), 'w', encoding='utf-8') as f:
f.write(text) f.write(text)
return f"Saved to prompts/{fname}" return f"Saved to prompts/{fname}"
@@ -153,7 +144,7 @@ def load_prompt(fname):
if text[-1] == '\n': if text[-1] == '\n':
text = text[:-1] text = text[:-1]
return text return text
def create_prompt_menus(): def create_prompt_menus():
with gr.Row(): with gr.Row():
with gr.Column(): with gr.Column():
@@ -170,10 +161,7 @@ def create_prompt_menus():
shared.gradio['save_prompt'].click(save_prompt, [shared.gradio['textbox']], [shared.gradio['status']], show_progress=False) shared.gradio['save_prompt'].click(save_prompt, [shared.gradio['textbox']], [shared.gradio['status']], show_progress=False)
def create_settings_menus(default_preset): def create_settings_menus(default_preset):
generate_params = load_preset_values(default_preset if not shared.args.flexgen else 'Naive', {}, return_dict=True) generate_params = load_preset_values(default_preset if not shared.args.flexgen else 'Naive', return_dict=True)
for k in ['max_new_tokens', 'seed', 'stop_at_newline', 'chat_prompt_size', 'chat_generation_attempts']:
generate_params[k] = shared.settings[k]
shared.gradio['generate_state'] = gr.State(generate_params)
with gr.Row(): with gr.Row():
with gr.Column(): with gr.Column():
@@ -224,16 +212,17 @@ def create_settings_menus(default_preset):
with gr.Row(): with gr.Row():
shared.gradio['upload_softprompt'] = gr.File(type='binary', file_types=['.zip']) shared.gradio['upload_softprompt'] = gr.File(type='binary', file_types=['.zip'])
shared.gradio['model_menu'].change(load_model_wrapper, shared.gradio['model_menu'], shared.gradio['model_menu'], show_progress=True) shared.gradio['model_menu'].change(load_model_wrapper, [shared.gradio['model_menu']], [shared.gradio['model_menu']], show_progress=True)
shared.gradio['preset_menu'].change(load_preset_values, [shared.gradio[k] for k in ['preset_menu', 'generate_state']], [shared.gradio[k] for k in ['generate_state', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']]) shared.gradio['preset_menu'].change(load_preset_values, [shared.gradio['preset_menu']], [shared.gradio[k] for k in ['do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']])
shared.gradio['lora_menu'].change(load_lora_wrapper, shared.gradio['lora_menu'], shared.gradio['lora_menu'], show_progress=True) shared.gradio['lora_menu'].change(load_lora_wrapper, [shared.gradio['lora_menu']], [shared.gradio['lora_menu']], show_progress=True)
shared.gradio['softprompts_menu'].change(load_soft_prompt, shared.gradio['softprompts_menu'], shared.gradio['softprompts_menu'], show_progress=True) shared.gradio['softprompts_menu'].change(load_soft_prompt, [shared.gradio['softprompts_menu']], [shared.gradio['softprompts_menu']], show_progress=True)
shared.gradio['upload_softprompt'].upload(upload_soft_prompt, shared.gradio['upload_softprompt'], shared.gradio['softprompts_menu']) shared.gradio['upload_softprompt'].upload(upload_soft_prompt, [shared.gradio['upload_softprompt']], [shared.gradio['softprompts_menu']])
def set_interface_arguments(interface_mode, extensions, bool_active): def set_interface_arguments(interface_mode, extensions, bool_active):
modes = ["default", "notebook", "chat", "cai_chat"] modes = ["default", "notebook", "chat", "cai_chat"]
cmd_list = vars(shared.args) cmd_list = vars(shared.args)
bool_list = [k for k in cmd_list if type(cmd_list[k]) is bool and k not in modes] bool_list = [k for k in cmd_list if type(cmd_list[k]) is bool and k not in modes]
#int_list = [k for k in cmd_list if type(k) is int]
shared.args.extensions = extensions shared.args.extensions = extensions
for k in modes[1:]: for k in modes[1:]:
@@ -306,7 +295,10 @@ def create_interface():
if shared.is_chat(): if shared.is_chat():
shared.gradio['Chat input'] = gr.State() shared.gradio['Chat input'] = gr.State()
with gr.Tab("Text generation", elem_id="main"): with gr.Tab("Text generation", elem_id="main"):
shared.gradio['display'] = gr.HTML(value=chat_html_wrapper(shared.history['visible'], shared.settings['name1'], shared.settings['name2'], 'cai-chat')) if shared.args.cai_chat:
shared.gradio['display'] = gr.HTML(value=generate_chat_html(shared.history['visible'], shared.settings['name1'], shared.settings['name2'], shared.character))
else:
shared.gradio['display'] = gr.Chatbot(value=shared.history['visible'], elem_id="gradio-chatbot")
shared.gradio['textbox'] = gr.Textbox(label='Input') shared.gradio['textbox'] = gr.Textbox(label='Input')
with gr.Row(): with gr.Row():
shared.gradio['Generate'] = gr.Button('Generate') shared.gradio['Generate'] = gr.Button('Generate')
@@ -323,20 +315,11 @@ def create_interface():
shared.gradio['Clear history-confirm'] = gr.Button('Confirm', variant="stop", visible=False) shared.gradio['Clear history-confirm'] = gr.Button('Confirm', variant="stop", visible=False)
shared.gradio['Clear history-cancel'] = gr.Button('Cancel', visible=False) shared.gradio['Clear history-cancel'] = gr.Button('Cancel', visible=False)
shared.gradio["Chat mode"] = gr.Radio(choices=["cai-chat", "chat", "instruct"], value="cai-chat", label="Mode")
shared.gradio["Instruction templates"] = gr.Dropdown(choices=get_available_instruction_templates(), label="Instruction template", value="None", visible=False)
with gr.Tab("Character", elem_id="chat-settings"): with gr.Tab("Character", elem_id="chat-settings"):
with gr.Row(): shared.gradio['name1'] = gr.Textbox(value=shared.settings['name1'], lines=1, label='Your name')
with gr.Column(scale=8): shared.gradio['name2'] = gr.Textbox(value=shared.settings['name2'], lines=1, label='Character\'s name')
shared.gradio['name1'] = gr.Textbox(value=shared.settings['name1'], lines=1, label='Your name') shared.gradio['greeting'] = gr.Textbox(value=shared.settings['greeting'], lines=2, label='Greeting')
shared.gradio['name2'] = gr.Textbox(value=shared.settings['name2'], lines=1, label='Character\'s name') shared.gradio['context'] = gr.Textbox(value=shared.settings['context'], lines=8, label='Context')
shared.gradio['greeting'] = gr.Textbox(value=shared.settings['greeting'], lines=4, label='Greeting')
shared.gradio['context'] = gr.Textbox(value=shared.settings['context'], lines=4, label='Context')
shared.gradio['end_of_turn'] = gr.Textbox(value=shared.settings["end_of_turn"], lines=1, label='End of turn string')
with gr.Column(scale=1):
shared.gradio['character_picture'] = gr.Image(label='Character picture', type="pil")
shared.gradio['your_picture'] = gr.Image(label='Your picture', type="pil", value=Image.open(Path("cache/pfp_me.png")) if Path("cache/pfp_me.png").exists() else None)
with gr.Row(): with gr.Row():
shared.gradio['character_menu'] = gr.Dropdown(choices=available_characters, value='None', label='Character', elem_id='character-menu') shared.gradio['character_menu'] = gr.Dropdown(choices=available_characters, value='None', label='Character', elem_id='character-menu')
ui.create_refresh_button(shared.gradio['character_menu'], lambda : None, lambda : {'choices': get_available_characters()}, 'refresh-button') ui.create_refresh_button(shared.gradio['character_menu'], lambda : None, lambda : {'choices': get_available_characters()}, 'refresh-button')
@@ -364,6 +347,8 @@ def create_interface():
gr.Markdown("# TavernAI PNG format") gr.Markdown("# TavernAI PNG format")
shared.gradio['upload_img_tavern'] = gr.File(type='binary', file_types=['image']) shared.gradio['upload_img_tavern'] = gr.File(type='binary', file_types=['image'])
with gr.Tab('Upload your profile picture'):
shared.gradio['upload_img_me'] = gr.File(type='binary', file_types=['image'])
with gr.Tab("Parameters", elem_id="parameters"): with gr.Tab("Parameters", elem_id="parameters"):
with gr.Box(): with gr.Box():
@@ -374,35 +359,35 @@ def create_interface():
shared.gradio['chat_prompt_size_slider'] = gr.Slider(minimum=shared.settings['chat_prompt_size_min'], maximum=shared.settings['chat_prompt_size_max'], step=1, label='Maximum prompt size in tokens', value=shared.settings['chat_prompt_size']) shared.gradio['chat_prompt_size_slider'] = gr.Slider(minimum=shared.settings['chat_prompt_size_min'], maximum=shared.settings['chat_prompt_size_max'], step=1, label='Maximum prompt size in tokens', value=shared.settings['chat_prompt_size'])
with gr.Column(): with gr.Column():
shared.gradio['chat_generation_attempts'] = gr.Slider(minimum=shared.settings['chat_generation_attempts_min'], maximum=shared.settings['chat_generation_attempts_max'], value=shared.settings['chat_generation_attempts'], step=1, label='Generation attempts (for longer replies)') shared.gradio['chat_generation_attempts'] = gr.Slider(minimum=shared.settings['chat_generation_attempts_min'], maximum=shared.settings['chat_generation_attempts_max'], value=shared.settings['chat_generation_attempts'], step=1, label='Generation attempts (for longer replies)')
shared.gradio['stop_at_newline'] = gr.Checkbox(value=shared.settings['stop_at_newline'], label='Stop generating at new line character?') shared.gradio['check'] = gr.Checkbox(value=shared.settings['stop_at_newline'], label='Stop generating at new line character?')
create_settings_menus(default_preset) create_settings_menus(default_preset)
shared.input_params = [shared.gradio[k] for k in ['Chat input', 'generate_state', 'name1', 'name2', 'context', 'Chat mode', 'end_of_turn']] function_call = 'chat.cai_chatbot_wrapper' if shared.args.cai_chat else 'chat.chatbot_wrapper'
shared.input_params = [shared.gradio[k] for k in ['Chat input', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'seed', 'name1', 'name2', 'context', 'check', 'chat_prompt_size_slider', 'chat_generation_attempts']]
def set_chat_input(textbox): def set_chat_input(textbox):
return textbox, "" return textbox, ""
gen_events.append(shared.gradio['Generate'].click(set_chat_input, shared.gradio['textbox'], [shared.gradio['Chat input'], shared.gradio['textbox']], show_progress=False)) gen_events.append(shared.gradio['Generate'].click(set_chat_input, shared.gradio['textbox'], [shared.gradio['Chat input'], shared.gradio['textbox']], show_progress=False))
gen_events.append(shared.gradio['Generate'].click(chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream)) gen_events.append(shared.gradio['Generate'].click(eval(function_call), shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream))
gen_events.append(shared.gradio['textbox'].submit(set_chat_input, shared.gradio['textbox'], [shared.gradio['Chat input'], shared.gradio['textbox']], show_progress=False)) gen_events.append(shared.gradio['textbox'].submit(set_chat_input, shared.gradio['textbox'], [shared.gradio['Chat input'], shared.gradio['textbox']], show_progress=False))
gen_events.append(shared.gradio['textbox'].submit(chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream)) gen_events.append(shared.gradio['textbox'].submit(eval(function_call), shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream))
gen_events.append(shared.gradio['Regenerate'].click(chat.regenerate_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream)) gen_events.append(shared.gradio['Regenerate'].click(chat.regenerate_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream))
gen_events.append(shared.gradio['Impersonate'].click(chat.impersonate_wrapper, shared.input_params, shared.gradio['textbox'], show_progress=shared.args.no_stream)) gen_events.append(shared.gradio['Impersonate'].click(chat.impersonate_wrapper, shared.input_params, shared.gradio['textbox'], show_progress=shared.args.no_stream))
shared.gradio['Stop'].click(stop_everything_event, [], [], queue=False, cancels=gen_events if shared.args.no_stream else None) shared.gradio['Stop'].click(stop_everything_event, [], [], queue=False, cancels=gen_events if shared.args.no_stream else None)
shared.gradio['Copy last reply'].click(chat.send_last_reply_to_input, [], shared.gradio['textbox'], show_progress=shared.args.no_stream) shared.gradio['Copy last reply'].click(chat.send_last_reply_to_input, [], shared.gradio['textbox'], show_progress=shared.args.no_stream)
shared.gradio['Replace last reply'].click(chat.replace_last_reply, [shared.gradio[k] for k in ['textbox', 'name1', 'name2', 'Chat mode']], shared.gradio['display'], show_progress=shared.args.no_stream) shared.gradio['Replace last reply'].click(chat.replace_last_reply, [shared.gradio['textbox'], shared.gradio['name1'], shared.gradio['name2']], shared.gradio['display'], show_progress=shared.args.no_stream)
# Clear history with confirmation # Clear history with confirmation
clear_arr = [shared.gradio[k] for k in ['Clear history-confirm', 'Clear history', 'Clear history-cancel']] clear_arr = [shared.gradio[k] for k in ['Clear history-confirm', 'Clear history', 'Clear history-cancel']]
shared.gradio['Clear history'].click(lambda :[gr.update(visible=True), gr.update(visible=False), gr.update(visible=True)], None, clear_arr) shared.gradio['Clear history'].click(lambda :[gr.update(visible=True), gr.update(visible=False), gr.update(visible=True)], None, clear_arr)
shared.gradio['Clear history-confirm'].click(lambda :[gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, clear_arr) shared.gradio['Clear history-confirm'].click(lambda :[gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, clear_arr)
shared.gradio['Clear history-confirm'].click(chat.clear_chat_log, [shared.gradio[k] for k in ['name1', 'name2', 'greeting', 'Chat mode']], shared.gradio['display']) shared.gradio['Clear history-confirm'].click(chat.clear_chat_log, [shared.gradio['name1'], shared.gradio['name2'], shared.gradio['greeting']], shared.gradio['display'])
shared.gradio['Clear history-cancel'].click(lambda :[gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, clear_arr) shared.gradio['Clear history-cancel'].click(lambda :[gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, clear_arr)
shared.gradio['Chat mode'].change(lambda x : gr.update(visible= x=='instruct'), shared.gradio['Chat mode'], shared.gradio['Instruction templates'])
shared.gradio['Remove last'].click(chat.remove_last_message, [shared.gradio[k] for k in ['name1', 'name2', 'Chat mode']], [shared.gradio['display'], shared.gradio['textbox']], show_progress=False) shared.gradio['Remove last'].click(chat.remove_last_message, [shared.gradio['name1'], shared.gradio['name2']], [shared.gradio['display'], shared.gradio['textbox']], show_progress=False)
shared.gradio['download_button'].click(chat.save_history, inputs=[], outputs=[shared.gradio['download']]) shared.gradio['download_button'].click(chat.save_history, inputs=[], outputs=[shared.gradio['download']])
shared.gradio['Upload character'].click(chat.upload_character, [shared.gradio['upload_json'], shared.gradio['upload_img_bot']], [shared.gradio['character_menu']]) shared.gradio['Upload character'].click(chat.upload_character, [shared.gradio['upload_json'], shared.gradio['upload_img_bot']], [shared.gradio['character_menu']])
@@ -414,21 +399,20 @@ def create_interface():
shared.gradio['textbox'].submit(lambda x: '', shared.gradio['textbox'], shared.gradio['textbox'], show_progress=False) shared.gradio['textbox'].submit(lambda x: '', shared.gradio['textbox'], shared.gradio['textbox'], show_progress=False)
shared.gradio['textbox'].submit(lambda : chat.save_history(timestamp=False), [], [], show_progress=False) shared.gradio['textbox'].submit(lambda : chat.save_history(timestamp=False), [], [], show_progress=False)
shared.gradio['character_menu'].change(chat.load_character, [shared.gradio[k] for k in ['character_menu', 'name1', 'name2', 'Chat mode']], [shared.gradio[k] for k in ['name1', 'name2', 'character_picture', 'greeting', 'context', 'end_of_turn', 'display']]) shared.gradio['character_menu'].change(chat.load_character, [shared.gradio[k] for k in ['character_menu', 'name1', 'name2']], [shared.gradio[k] for k in ['name1', 'name2', 'greeting', 'context', 'display']])
shared.gradio['Instruction templates'].change(lambda character, name1, name2, mode: chat.load_character(character, name1, name2, mode), [shared.gradio[k] for k in ['Instruction templates', 'name1', 'name2', 'Chat mode']], [shared.gradio[k] for k in ['name1', 'name2', 'character_picture', 'greeting', 'context', 'end_of_turn', 'display']]) shared.gradio['upload_chat_history'].upload(chat.load_history, [shared.gradio['upload_chat_history'], shared.gradio['name1'], shared.gradio['name2']], [])
shared.gradio['upload_chat_history'].upload(chat.load_history, [shared.gradio[k] for k in ['upload_chat_history', 'name1', 'name2']], [])
shared.gradio['upload_img_tavern'].upload(chat.upload_tavern_character, [shared.gradio['upload_img_tavern'], shared.gradio['name1'], shared.gradio['name2']], [shared.gradio['character_menu']]) shared.gradio['upload_img_tavern'].upload(chat.upload_tavern_character, [shared.gradio['upload_img_tavern'], shared.gradio['name1'], shared.gradio['name2']], [shared.gradio['character_menu']])
shared.gradio['your_picture'].change(chat.upload_your_profile_picture, [shared.gradio[k] for k in ['your_picture', 'name1', 'name2', 'Chat mode']], shared.gradio['display']) shared.gradio['upload_img_me'].upload(chat.upload_your_profile_picture, [shared.gradio['upload_img_me']], [])
reload_inputs = [shared.gradio[k] for k in ['name1', 'name2', 'Chat mode']] reload_func = chat.redraw_html if shared.args.cai_chat else lambda : shared.history['visible']
shared.gradio['upload_chat_history'].upload(chat.redraw_html, reload_inputs, [shared.gradio['display']]) reload_inputs = [shared.gradio['name1'], shared.gradio['name2']] if shared.args.cai_chat else []
shared.gradio['Stop'].click(chat.redraw_html, reload_inputs, [shared.gradio['display']]) shared.gradio['upload_chat_history'].upload(reload_func, reload_inputs, [shared.gradio['display']])
shared.gradio['Instruction templates'].change(chat.redraw_html, reload_inputs, [shared.gradio['display']]) shared.gradio['upload_img_me'].upload(reload_func, reload_inputs, [shared.gradio['display']])
shared.gradio['Chat mode'].change(chat.redraw_html, reload_inputs, [shared.gradio['display']]) shared.gradio['Stop'].click(reload_func, reload_inputs, [shared.gradio['display']])
shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js+ui.chat_js}}}") shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js+ui.chat_js}}}")
shared.gradio['interface'].load(lambda : chat.load_default_history(shared.settings['name1'], shared.settings['name2']), None, None) shared.gradio['interface'].load(lambda : chat.load_default_history(shared.settings['name1'], shared.settings['name2']), None, None)
shared.gradio['interface'].load(chat.redraw_html, reload_inputs, [shared.gradio['display']], show_progress=True) shared.gradio['interface'].load(reload_func, reload_inputs, [shared.gradio['display']], show_progress=True)
elif shared.args.notebook: elif shared.args.notebook:
with gr.Tab("Text generation", elem_id="main"): with gr.Tab("Text generation", elem_id="main"):
@@ -458,9 +442,9 @@ def create_interface():
with gr.Tab("Parameters", elem_id="parameters"): with gr.Tab("Parameters", elem_id="parameters"):
create_settings_menus(default_preset) create_settings_menus(default_preset)
shared.input_params = [shared.gradio[k] for k in ['textbox', 'generate_state']] shared.input_params = [shared.gradio[k] for k in ['textbox', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'seed']]
output_params = [shared.gradio[k] for k in ['textbox', 'markdown', 'html']] output_params = [shared.gradio[k] for k in ['textbox', 'markdown', 'html']]
gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream)) gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream, api_name='textgen'))
gen_events.append(shared.gradio['textbox'].submit(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream)) gen_events.append(shared.gradio['textbox'].submit(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream))
shared.gradio['Stop'].click(stop_everything_event, [], [], queue=False, cancels=gen_events if shared.args.no_stream else None) shared.gradio['Stop'].click(stop_everything_event, [], [], queue=False, cancels=gen_events if shared.args.no_stream else None)
shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js}}}") shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js}}}")
@@ -491,9 +475,9 @@ def create_interface():
with gr.Tab("Parameters", elem_id="parameters"): with gr.Tab("Parameters", elem_id="parameters"):
create_settings_menus(default_preset) create_settings_menus(default_preset)
shared.input_params = [shared.gradio[k] for k in ['textbox', 'generate_state']] shared.input_params = [shared.gradio[k] for k in ['textbox', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'seed']]
output_params = [shared.gradio[k] for k in ['output_textbox', 'markdown', 'html']] output_params = [shared.gradio[k] for k in ['output_textbox', 'markdown', 'html']]
gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream)) gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream, api_name='textgen'))
gen_events.append(shared.gradio['textbox'].submit(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream)) gen_events.append(shared.gradio['textbox'].submit(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream))
gen_events.append(shared.gradio['Continue'].click(generate_reply, [shared.gradio['output_textbox']] + shared.input_params[1:], output_params, show_progress=shared.args.no_stream)) gen_events.append(shared.gradio['Continue'].click(generate_reply, [shared.gradio['output_textbox']] + shared.input_params[1:], output_params, show_progress=shared.args.no_stream))
shared.gradio['Stop'].click(stop_everything_event, [], [], queue=False, cancels=gen_events if shared.args.no_stream else None) shared.gradio['Stop'].click(stop_everything_event, [], [], queue=False, cancels=gen_events if shared.args.no_stream else None)
@@ -526,21 +510,6 @@ def create_interface():
if shared.args.extensions is not None: if shared.args.extensions is not None:
extensions_module.create_extensions_block() extensions_module.create_extensions_block()
def change_dict_value(d, key, value):
d[key] = value
return d
for k in ['do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'max_new_tokens', 'seed', 'stop_at_newline', 'chat_prompt_size_slider', 'chat_generation_attempts']:
if k not in shared.gradio:
continue
if type(shared.gradio[k]) in [gr.Checkbox, gr.Number]:
shared.gradio[k].change(lambda state, value, copy=k: change_dict_value(state, copy, value), inputs=[shared.gradio['generate_state'], shared.gradio[k]], outputs=shared.gradio['generate_state'])
else:
shared.gradio[k].release(lambda state, value, copy=k: change_dict_value(state, copy, value), inputs=[shared.gradio['generate_state'], shared.gradio[k]], outputs=shared.gradio['generate_state'])
if not shared.is_chat():
api.create_apis()
# Authentication # Authentication
auth = None auth = None
if shared.args.gradio_auth_path is not None: if shared.args.gradio_auth_path is not None: