6 Commits

Author SHA1 Message Date
oobabooga
776b7914bf Delete Open Assistant.txt 2023-04-04 12:55:55 -03:00
oobabooga
2944c6d204 Delete Alpaca.txt 2023-04-04 12:55:47 -03:00
oobabooga
cbaa231a0a Delete dummy.txt 2023-04-04 12:55:22 -03:00
oobabooga
065383ec67 Add files via upload 2023-04-04 12:55:05 -03:00
oobabooga
214dd6307e Create dummy.txt 2023-04-04 12:54:37 -03:00
oobabooga
a500061b08 Create script.py 2023-04-04 12:54:04 -03:00
26 changed files with 323 additions and 459 deletions

View File

@@ -26,7 +26,7 @@ Its goal is to become the [AUTOMATIC1111/stable-diffusion-webui](https://github.
* CPU mode
* [FlexGen](https://github.com/oobabooga/text-generation-webui/wiki/FlexGen)
* [DeepSpeed ZeRO-3](https://github.com/oobabooga/text-generation-webui/wiki/DeepSpeed)
* API [with](https://github.com/oobabooga/text-generation-webui/blob/main/api-example-stream.py) streaming and [without](https://github.com/oobabooga/text-generation-webui/blob/main/api-example.py) streaming
* API [with](https://github.com/oobabooga/text-generation-webui/blob/main/api-example-streaming.py) streaming and [without](https://github.com/oobabooga/text-generation-webui/blob/main/api-example.py) streaming
* [LLaMA model, including 4-bit GPTQ](https://github.com/oobabooga/text-generation-webui/wiki/LLaMA-model)
* [llama.cpp](https://github.com/oobabooga/text-generation-webui/wiki/llama.cpp-models) **\*NEW!\***
* [RWKV model](https://github.com/oobabooga/text-generation-webui/wiki/RWKV-model)
@@ -62,7 +62,7 @@ Recommended if you have some experience with the command-line.
On Windows, I additionally recommend carrying out the installation on WSL instead of the base system: [WSL installation guide](https://github.com/oobabooga/text-generation-webui/wiki/WSL-installation-guide).
#### 0. Install Conda
0. Install Conda
https://docs.conda.io/en/latest/miniconda.html
@@ -75,14 +75,14 @@ bash Miniconda3.sh
Source: https://educe-ubc.github.io/conda.html
#### 1. Create a new conda environment
1. Create a new conda environment
```
conda create -n textgen python=3.10.9
conda activate textgen
```
#### 2. Install Pytorch
2. Install Pytorch
| System | GPU | Command |
|--------|---------|---------|
@@ -92,12 +92,10 @@ conda activate textgen
The up to date commands can be found here: https://pytorch.org/get-started/locally/.
#### 2.1 Special instructions
MacOS users, refer to the comments here: https://github.com/oobabooga/text-generation-webui/pull/393
* MacOS users: https://github.com/oobabooga/text-generation-webui/pull/393
* AMD users: https://rentry.org/eq3hg
#### 3. Install the web UI
3. Install the web UI
```
git clone https://github.com/oobabooga/text-generation-webui
@@ -177,6 +175,7 @@ Optionally, you can use the following command-line flags:
| `-h`, `--help` | show this help message and exit |
| `--notebook` | Launch the web UI in notebook mode, where the output is written to the same text box as the input. |
| `--chat` | Launch the web UI in chat mode.|
| `--cai-chat` | Launch the web UI in chat mode with a style similar to Character.AI's. If the file `img_bot.png` or `img_bot.jpg` exists in the same folder as server.py, this image will be used as the bot's profile picture. Similarly, `img_me.png` or `img_me.jpg` will be used as your profile picture. |
| `--model MODEL` | Name of the model to load by default. |
| `--lora LORA` | Name of the LoRA to apply to the model by default. |
| `--model-dir MODEL_DIR` | Path to directory with all the models |

View File

@@ -36,7 +36,6 @@ async def run(context):
'early_stopping': False,
'seed': -1,
}
payload = json.dumps([context, params])
session = random_hash()
async with websockets.connect(f"ws://{server}:7860/queue/join") as websocket:
@@ -55,7 +54,22 @@ async def run(context):
"session_hash": session,
"fn_index": 12,
"data": [
payload
context,
params['max_new_tokens'],
params['do_sample'],
params['temperature'],
params['top_p'],
params['typical_p'],
params['repetition_penalty'],
params['encoder_repetition_penalty'],
params['top_k'],
params['min_length'],
params['no_repeat_ngram_size'],
params['num_beams'],
params['penalty_alpha'],
params['length_penalty'],
params['early_stopping'],
params['seed'],
]
}))
case "process_starts":

View File

@@ -10,8 +10,6 @@ Optionally, you can also add the --share flag to generate a public gradio URL,
allowing you to use the API remotely.
'''
import json
import requests
# Server address
@@ -40,11 +38,24 @@ params = {
# Input prompt
prompt = "What I would like to say is the following: "
payload = json.dumps([prompt, params])
response = requests.post(f"http://{server}:7860/run/textgen", json={
"data": [
payload
prompt,
params['max_new_tokens'],
params['do_sample'],
params['temperature'],
params['top_p'],
params['typical_p'],
params['repetition_penalty'],
params['encoder_repetition_penalty'],
params['top_k'],
params['min_length'],
params['no_repeat_ngram_size'],
params['num_beams'],
params['penalty_alpha'],
params['length_penalty'],
params['early_stopping'],
params['seed'],
]
}).json()

View File

@@ -1,16 +1,15 @@
name: "Chiharu Yamada"
context: "Chiharu Yamada's Persona: Chiharu Yamada is a young, computer engineer-nerd with a knack for problem solving and a passion for technology."
greeting: |-
*Chiharu strides into the room with a smile, her eyes lighting up when she sees you. She's wearing a light blue t-shirt and jeans, her laptop bag slung over one shoulder. She takes a seat next to you, her enthusiasm palpable in the air*
Hey! I'm so excited to finally meet you. I've heard so many great things about you and I'm eager to pick your brain about computers. I'm sure you have a wealth of knowledge that I can learn from. *She grins, eyes twinkling with excitement* Let's get started!
*Chiharu strides into the room with a smile, her eyes lighting up when she sees you. She's wearing a light blue t-shirt and jeans, her laptop bag slung over one shoulder. She takes a seat next to you, her enthusiasm palpable in the air*
Hey! I'm so excited to finally meet you. I've heard so many great things about you and I'm eager to pick your brain about computers. I'm sure you have a wealth of knowledge that I can learn from. *She grins, eyes twinkling with excitement* Let's get started!
example_dialogue: |-
{{user}}: So how did you get into computer engineering?
{{char}}: I've always loved tinkering with technology since I was a kid.
{{user}}: That's really impressive!
{{char}}: *She chuckles bashfully* Thanks!
{{user}}: So what do you do when you're not working on computers?
{{char}}: I love exploring, going out with friends, watching movies, and playing video games.
{{user}}: What's your favorite type of computer hardware to work with?
{{char}}: Motherboards, they're like puzzles and the backbone of any system.
{{user}}: That sounds great!
{{char}}: Yeah, it's really fun. I'm lucky to be able to do this as a job.
{{user}}: So how did you get into computer engineering?
{{char}}: I've always loved tinkering with technology since I was a kid.
{{user}}: That's really impressive!
{{char}}: *She chuckles bashfully* Thanks!
{{user}}: So what do you do when you're not working on computers?
{{char}}: I love exploring, going out with friends, watching movies, and playing video games.
{{user}}: What's your favorite type of computer hardware to work with?
{{char}}: Motherboards, they're like puzzles and the backbone of any system.
{{user}}: That sounds great!
{{char}}: Yeah, it's really fun. I'm lucky to be able to do this as a job.

View File

@@ -1,3 +0,0 @@
name: "### Response:"
your_name: "### Instruction:"
context: "Below is an instruction that describes a task. Write a response that appropriately completes the request."

View File

@@ -1,3 +0,0 @@
name: "<|assistant|>"
your_name: "<|prompter|>"
end_of_turn: "<|endoftext|>"

View File

@@ -1,56 +0,0 @@
.chat {
margin-left: auto;
margin-right: auto;
max-width: 800px;
height: 66.67vh;
overflow-y: auto;
padding-right: 20px;
display: flex;
flex-direction: column-reverse;
}
.message {
display: grid;
grid-template-columns: 60px 1fr;
padding-bottom: 25px;
font-size: 15px;
font-family: Helvetica, Arial, sans-serif;
line-height: 1.428571429;
}
.text p {
margin-top: 5px;
}
.username {
display: none;
}
.message-body {}
.message-body p {
margin-bottom: 0 !important;
font-size: 15px !important;
line-height: 1.428571429 !important;
}
.dark .message-body p em {
color: rgb(138, 138, 138) !important;
}
.message-body p em {
color: rgb(110, 110, 110) !important;
}
.assistant-message {
padding: 10px;
}
.user-message {
padding: 10px;
background-color: #f1f1f1;
}
.dark .user-message {
background-color: #ffffff1a;
}

View File

@@ -63,7 +63,3 @@ span.math.inline {
font-size: 27px;
vertical-align: baseline !important;
}
div.svelte-15lo0d8 > *, div.svelte-15lo0d8 > .form > * {
flex-wrap: nowrap;
}

View File

@@ -40,27 +40,24 @@ class Handler(BaseHTTPRequestHandler):
prompt_lines.pop(0)
prompt = '\n'.join(prompt_lines)
generate_params = {
'max_new_tokens': int(body.get('max_length', 200)),
'do_sample': bool(body.get('do_sample', True)),
'temperature': float(body.get('temperature', 0.5)),
'top_p': float(body.get('top_p', 1)),
'typical_p': float(body.get('typical', 1)),
'repetition_penalty': float(body.get('rep_pen', 1.1)),
'encoder_repetition_penalty': 1,
'top_k': int(body.get('top_k', 0)),
'min_length': int(body.get('min_length', 0)),
'no_repeat_ngram_size': int(body.get('no_repeat_ngram_size',0)),
'num_beams': int(body.get('num_beams',1)),
'penalty_alpha': float(body.get('penalty_alpha', 0)),
'length_penalty': float(body.get('length_penalty', 1)),
'early_stopping': bool(body.get('early_stopping', False)),
'seed': int(body.get('seed', -1)),
}
generator = generate_reply(
prompt,
generate_params,
question = prompt,
max_new_tokens = int(body.get('max_length', 200)),
do_sample=bool(body.get('do_sample', True)),
temperature=float(body.get('temperature', 0.5)),
top_p=float(body.get('top_p', 1)),
typical_p=float(body.get('typical', 1)),
repetition_penalty=float(body.get('rep_pen', 1.1)),
encoder_repetition_penalty=1,
top_k=int(body.get('top_k', 0)),
min_length=int(body.get('min_length', 0)),
no_repeat_ngram_size=int(body.get('no_repeat_ngram_size',0)),
num_beams=int(body.get('num_beams',1)),
penalty_alpha=float(body.get('penalty_alpha', 0)),
length_penalty=float(body.get('length_penalty', 1)),
early_stopping=bool(body.get('early_stopping', False)),
seed=int(body.get('seed', -1)),
stopping_strings=body.get('stopping_strings', []),
)

View File

@@ -2,8 +2,9 @@ from pathlib import Path
import gradio as gr
from modules.chat import load_character
from modules.html_generator import get_image_cache
from modules.shared import gradio
from modules.shared import gradio, settings
def generate_css():
@@ -63,13 +64,22 @@ def generate_html():
for file in sorted(Path("characters").glob("*")):
if file.suffix in [".json", ".yml", ".yaml"]:
character = file.stem
container_html = '<div class="character-container">'
container_html = f'<div class="character-container">'
image_html = "<div class='placeholder'></div>"
for path in [Path(f"characters/{character}.{extension}") for extension in ['png', 'jpg', 'jpeg']]:
for i in [
f"characters/{character}.png",
f"characters/{character}.jpg",
f"characters/{character}.jpeg",
]:
path = Path(i)
if path.exists():
image_html = f'<img src="file/{get_image_cache(path)}">'
break
try:
image_html = f'<img src="file/{get_image_cache(path)}">'
break
except:
continue
container_html += f'{image_html} <span class="character-name">{character}</span>'
container_html += "</div>"

View File

@@ -0,0 +1,51 @@
from pathlib import Path
import gradio as gr
from modules import shared
from modules import ui as _ui
params = {
'template': '%input%'
}
def get_available_templates():
return ['None'] + sorted(set((k.stem for k in Path('extensions/prompt_template/templates').glob('*.txt'))), key=str.lower)
def load_template(fname):
if fname in ['None', '']:
return '%input%'
else:
with open(Path(f'extensions/prompt_template/templates/{fname}.txt'), 'r', encoding='utf-8') as f:
text = f.read()
if text[-1] == '\n':
text = text[:-1]
return text
def input_modifier(string):
"""
This function is applied to your text inputs before
they are fed into the model.
"""
return params['template'].replace('%input%', string)
def output_modifier(string):
return f'\n{string}'
def setup():
shared.args.verbose = True
def ui():
# Gradio elements
with gr.Row():
with gr.Column():
template = gr.Textbox(value=params['template'], info="%input% will be replaced with your user input.", label='Template')
with gr.Column():
with gr.Row():
template_menu = gr.Dropdown(choices=get_available_templates(), value='None', label='Available templates')
_ui.create_refresh_button(shared.gradio['model_menu'], lambda : None, lambda : {'choices': get_available_templates()}, 'refresh-button')
template_menu.change(load_template, template_menu, template)
template.change(lambda x: params.update({"template": x}), template, None)

View File

@@ -1,6 +1,5 @@
Below is an instruction that describes a task. Write a response that appropriately completes the request.
### Instruction:
Write a poem about the transformers Python library.
Mention the word "large language models" in that poem.
%input%
### Response:

View File

@@ -0,0 +1 @@
<|prompter|>%input%<|endoftext|><|assistant|>

View File

@@ -176,4 +176,4 @@ def ui():
force_btn.click(force_pic)
generate_now_btn.click(force_pic)
generate_now_btn.click(chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream)
generate_now_btn.click(eval('chat.cai_chatbot_wrapper'), shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream)

View File

@@ -2,11 +2,12 @@ import base64
from io import BytesIO
import gradio as gr
import modules.chat as chat
import modules.shared as shared
import torch
from PIL import Image
from transformers import BlipForConditionalGeneration, BlipProcessor
from modules import chat, shared
# If 'state' is True, will hijack the next chat generation with
# custom input text given by 'value' in the format [text, visible_text]
input_hijack = {
@@ -35,11 +36,13 @@ def generate_chat_picture(picture, name1, name2):
def ui():
picture_select = gr.Image(label='Send a picture', type='pil')
function_call = 'chat.cai_chatbot_wrapper' if shared.args.cai_chat else 'chat.chatbot_wrapper'
# Prepare the hijack with custom inputs
picture_select.upload(lambda picture, name1, name2: input_hijack.update({"state": True, "value": generate_chat_picture(picture, name1, name2)}), [picture_select, shared.gradio['name1'], shared.gradio['name2']], None)
# Call the generation function
picture_select.upload(chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream)
picture_select.upload(eval(function_call), shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream)
# Clear the picture from the upload field
picture_select.upload(lambda : None, [], [picture_select], show_progress=False)

View File

@@ -52,7 +52,7 @@ def load_quantized(model_name):
if not shared.args.model_type:
# Try to determine model type from model name
name = model_name.lower()
if any((k in name for k in ['llama', 'alpaca', 'vicuna'])):
if any((k in name for k in ['llama', 'alpaca'])):
model_type = 'llama'
elif any((k in name for k in ['opt-', 'galactica'])):
model_type = 'opt'
@@ -65,18 +65,16 @@ def load_quantized(model_name):
else:
model_type = shared.args.model_type.lower()
if shared.args.pre_layer and model_type == 'llama':
if model_type == 'llama' and shared.args.pre_layer:
load_quant = llama_inference_offload.load_quant
elif model_type in ('llama', 'opt', 'gptj'):
if shared.args.pre_layer:
print("Warning: ignoring --pre_layer because it only works for llama model type.")
load_quant = _load_quant
else:
print("Unknown pre-quantized model type specified. Only 'llama', 'opt' and 'gptj' are supported")
exit()
# Now we are going to try to locate the quantized model file.
path_to_model = Path(f'{shared.args.model_dir}/{model_name}')
path_to_model = Path(f'models/{model_name}')
found_pts = list(path_to_model.glob("*.pt"))
found_safetensors = list(path_to_model.glob("*.safetensors"))
pt_path = None
@@ -97,8 +95,8 @@ def load_quantized(model_name):
else:
pt_model = f'{model_name}-{shared.args.wbits}bit'
# Try to find the .safetensors or .pt both in the model dir and in the subfolder
for path in [Path(p+ext) for ext in ['.safetensors', '.pt'] for p in [f"{shared.args.model_dir}/{pt_model}", f"{path_to_model}/{pt_model}"]]:
# Try to find the .safetensors or .pt both in models/ and in the subfolder
for path in [Path(p+ext) for ext in ['.safetensors', '.pt'] for p in [f"models/{pt_model}", f"{path_to_model}/{pt_model}"]]:
if path.exists():
print(f"Found {path}")
pt_path = path
@@ -109,7 +107,7 @@ def load_quantized(model_name):
exit()
# qwopqwop200's offload
if model_type == 'llama' and shared.args.pre_layer:
if shared.args.pre_layer:
model = load_quant(str(path_to_model), str(pt_path), shared.args.wbits, shared.args.groupsize, shared.args.pre_layer)
else:
threshold = False if model_type == 'gptj' else 128

View File

@@ -1,38 +0,0 @@
import json
import gradio as gr
from modules import shared
from modules.text_generation import generate_reply
def generate_reply_wrapper(string):
generate_params = {
'do_sample': True,
'temperature': 1,
'top_p': 1,
'typical_p': 1,
'repetition_penalty': 1,
'encoder_repetition_penalty': 1,
'top_k': 50,
'num_beams': 1,
'penalty_alpha': 0,
'min_length': 0,
'length_penalty': 1,
'no_repeat_ngram_size': 0,
'early_stopping': False,
}
params = json.loads(string)
for k in params[1]:
generate_params[k] = params[1][k]
for i in generate_reply(params[0], generate_params):
yield i
def create_apis():
t1 = gr.Textbox(visible=False)
t2 = gr.Textbox(visible=False)
dummy = gr.Button(visible=False)
input_params = [t1]
output_params = [t2] + [shared.gradio[k] for k in ['markdown', 'html']]
dummy.click(generate_reply_wrapper, input_params, output_params, api_name='textgen')

View File

@@ -12,56 +12,45 @@ from PIL import Image
import modules.extensions as extensions_module
import modules.shared as shared
from modules.extensions import apply_extensions
from modules.html_generator import (fix_newlines, chat_html_wrapper,
make_thumbnail)
from modules.html_generator import fix_newlines, generate_chat_html
from modules.text_generation import (encode, generate_reply,
get_max_prompt_length)
def generate_chat_prompt(user_input, max_new_tokens, name1, name2, context, chat_prompt_size, **kwargs):
is_instruct = kwargs['is_instruct'] if 'is_instruct' in kwargs else False
end_of_turn = kwargs['end_of_turn'] if 'end_of_turn' in kwargs else ''
impersonate = kwargs['impersonate'] if 'impersonate' in kwargs else False
also_return_rows = kwargs['also_return_rows'] if 'also_return_rows' in kwargs else False
def generate_chat_output(history, name1, name2, character):
if shared.args.cai_chat:
return generate_chat_html(history, name1, name2, character)
else:
return history
def generate_chat_prompt(user_input, max_new_tokens, name1, name2, context, chat_prompt_size, impersonate=False, also_return_rows=False):
user_input = fix_newlines(user_input)
rows = [f"{context.strip()}\n"]
# Finding the maximum prompt size
if shared.soft_prompt:
chat_prompt_size -= shared.soft_prompt_tensor.shape[1]
max_length = min(get_max_prompt_length(max_new_tokens), chat_prompt_size)
if is_instruct:
prefix1 = f"{name1}\n"
prefix2 = f"{name2}\n"
else:
prefix1 = f"{name1}: "
prefix2 = f"{name2}: "
i = len(shared.history['internal'])-1
while i >= 0 and len(encode(''.join(rows), max_new_tokens)[0]) < max_length:
rows.insert(1, f"{prefix2}{shared.history['internal'][i][1].strip()}{end_of_turn}\n")
string = shared.history['internal'][i][0]
if string not in ['', '<|BEGIN-VISIBLE-CHAT|>']:
rows.insert(1, f"{prefix1}{string.strip()}{end_of_turn}\n")
rows.insert(1, f"{name2}: {shared.history['internal'][i][1].strip()}\n")
prev_user_input = shared.history['internal'][i][0]
if prev_user_input not in ['', '<|BEGIN-VISIBLE-CHAT|>']:
rows.insert(1, f"{name1}: {prev_user_input.strip()}\n")
i -= 1
if impersonate:
rows.append(f"{prefix1.strip() if not is_instruct else prefix1}")
limit = 2
else:
# Adding the user message
if not impersonate:
if len(user_input) > 0:
rows.append(f"{prefix1}{user_input}{end_of_turn}\n")
# Adding the Character prefix
rows.append(apply_extensions(f"{prefix2.strip() if not is_instruct else prefix2}", "bot_prefix"))
rows.append(f"{name1}: {user_input}\n")
rows.append(apply_extensions(f"{name2}:", "bot_prefix"))
limit = 3
else:
rows.append(f"{name1}:")
limit = 2
while len(rows) > limit and len(encode(''.join(rows), max_new_tokens)[0]) >= max_length:
rows.pop(1)
prompt = ''.join(rows)
if also_return_rows:
@@ -96,9 +85,9 @@ def extract_message_from_reply(reply, name1, name2, stop_at_newline):
reply = fix_newlines(reply)
return reply, next_character_found
def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn, regenerate=False):
def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, name1, name2, context, stop_at_newline, chat_prompt_size, chat_generation_attempts=1, regenerate=False):
just_started = True
eos_token = '\n' if generate_state['stop_at_newline'] else None
eos_token = '\n' if stop_at_newline else None
name1_original = name1
if 'pygmalion' in shared.model_name.lower():
name1 = "You"
@@ -115,13 +104,14 @@ def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_tu
if visible_text is None:
visible_text = text
if shared.args.chat:
visible_text = visible_text.replace('\n', '<br>')
text = apply_extensions(text, "input")
kwargs = {'end_of_turn': end_of_turn, 'is_instruct': mode == 'instruct'}
if custom_generate_chat_prompt is None:
prompt = generate_chat_prompt(text, generate_state['max_new_tokens'], name1, name2, context, generate_state['chat_prompt_size'], **kwargs)
prompt = generate_chat_prompt(text, max_new_tokens, name1, name2, context, chat_prompt_size)
else:
prompt = custom_generate_chat_prompt(text, generate_state['max_new_tokens'], name1, name2, context, generate_state['chat_prompt_size'], **kwargs)
prompt = custom_generate_chat_prompt(text, max_new_tokens, name1, name2, context, chat_prompt_size)
# Yield *Is typing...*
if not regenerate:
@@ -129,15 +119,17 @@ def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_tu
# Generate
cumulative_reply = ''
for i in range(generate_state['chat_generation_attempts']):
for i in range(chat_generation_attempts):
reply = None
for reply in generate_reply(f"{prompt}{' ' if len(cumulative_reply) > 0 else ''}{cumulative_reply}", generate_state, eos_token=eos_token, stopping_strings=[f"\n{name1}:", f"\n{name2}:"]):
for reply in generate_reply(f"{prompt}{' ' if len(cumulative_reply) > 0 else ''}{cumulative_reply}", max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, eos_token=eos_token, stopping_strings=[f"\n{name1}:", f"\n{name2}:"]):
reply = cumulative_reply + reply
# Extracting the reply
reply, next_character_found = extract_message_from_reply(reply, name1, name2, generate_state['stop_at_newline'])
reply, next_character_found = extract_message_from_reply(reply, name1, name2, stop_at_newline)
visible_reply = re.sub("(<USER>|<user>|{{user}})", name1_original, reply)
visible_reply = apply_extensions(visible_reply, "output")
if shared.args.chat:
visible_reply = visible_reply.replace('\n', '<br>')
# We need this global variable to handle the Stop event,
# otherwise gradio gets confused
@@ -160,23 +152,23 @@ def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_tu
yield shared.history['visible']
def impersonate_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn):
eos_token = '\n' if generate_state['stop_at_newline'] else None
def impersonate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, name1, name2, context, stop_at_newline, chat_prompt_size, chat_generation_attempts=1):
eos_token = '\n' if stop_at_newline else None
if 'pygmalion' in shared.model_name.lower():
name1 = "You"
prompt = generate_chat_prompt(text, generate_state['max_new_tokens'], name1, name2, context, generate_state['chat_prompt_size'], impersonate=True, end_of_turn=end_of_turn)
prompt = generate_chat_prompt(text, max_new_tokens, name1, name2, context, chat_prompt_size, impersonate=True)
# Yield *Is typing...*
yield shared.processing_message
cumulative_reply = ''
for i in range(generate_state['chat_generation_attempts']):
for i in range(chat_generation_attempts):
reply = None
for reply in generate_reply(f"{prompt}{' ' if len(cumulative_reply) > 0 else ''}{cumulative_reply}", generate_state, eos_token=eos_token, stopping_strings=[f"\n{name1}:", f"\n{name2}:"]):
for reply in generate_reply(f"{prompt}{' ' if len(cumulative_reply) > 0 else ''}{cumulative_reply}", max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, eos_token=eos_token, stopping_strings=[f"\n{name1}:", f"\n{name2}:"]):
reply = cumulative_reply + reply
reply, next_character_found = extract_message_from_reply(reply, name1, name2, generate_state['stop_at_newline'])
reply, next_character_found = extract_message_from_reply(reply, name1, name2, stop_at_newline)
yield reply
if next_character_found:
break
@@ -186,30 +178,36 @@ def impersonate_wrapper(text, generate_state, name1, name2, context, mode, end_o
yield reply
def cai_chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn):
for history in chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn, regenerate=False):
yield chat_html_wrapper(history, name1, name2, mode)
def cai_chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, name1, name2, context, stop_at_newline, chat_prompt_size, chat_generation_attempts=1):
for history in chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, name1, name2, context, stop_at_newline, chat_prompt_size, chat_generation_attempts):
yield generate_chat_html(history, name1, name2, shared.character)
def regenerate_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn):
def regenerate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, name1, name2, context, stop_at_newline, chat_prompt_size, chat_generation_attempts=1):
if (shared.character != 'None' and len(shared.history['visible']) == 1) or len(shared.history['internal']) == 0:
yield chat_html_wrapper(shared.history['visible'], name1, name2, mode)
yield generate_chat_output(shared.history['visible'], name1, name2, shared.character)
else:
last_visible = shared.history['visible'].pop()
last_internal = shared.history['internal'].pop()
# Yield '*Is typing...*'
yield chat_html_wrapper(shared.history['visible']+[[last_visible[0], shared.processing_message]], name1, name2, mode)
for history in chatbot_wrapper(last_internal[0], generate_state, name1, name2, context, mode, end_of_turn, regenerate=True):
shared.history['visible'][-1] = [last_visible[0], history[-1][1]]
yield chat_html_wrapper(shared.history['visible'], name1, name2, mode)
yield generate_chat_output(shared.history['visible']+[[last_visible[0], shared.processing_message]], name1, name2, shared.character)
for history in chatbot_wrapper(last_internal[0], max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, name1, name2, context, stop_at_newline, chat_prompt_size, chat_generation_attempts, regenerate=True):
if shared.args.cai_chat:
shared.history['visible'][-1] = [last_visible[0], history[-1][1]]
else:
shared.history['visible'][-1] = (last_visible[0], history[-1][1])
yield generate_chat_output(shared.history['visible'], name1, name2, shared.character)
def remove_last_message(name1, name2, mode):
def remove_last_message(name1, name2):
if len(shared.history['visible']) > 0 and shared.history['internal'][-1][0] != '<|BEGIN-VISIBLE-CHAT|>':
last = shared.history['visible'].pop()
shared.history['internal'].pop()
else:
last = ['', '']
return chat_html_wrapper(shared.history['visible'], name1, name2, mode), last[0]
if shared.args.cai_chat:
return generate_chat_html(shared.history['visible'], name1, name2, shared.character), last[0]
else:
return shared.history['visible'], last[0]
def send_last_reply_to_input():
if len(shared.history['internal']) > 0:
@@ -217,17 +215,20 @@ def send_last_reply_to_input():
else:
return ''
def replace_last_reply(text, name1, name2, mode):
def replace_last_reply(text, name1, name2):
if len(shared.history['visible']) > 0:
shared.history['visible'][-1][1] = text
if shared.args.cai_chat:
shared.history['visible'][-1][1] = text
else:
shared.history['visible'][-1] = (shared.history['visible'][-1][0], text)
shared.history['internal'][-1][1] = apply_extensions(text, "input")
return chat_html_wrapper(shared.history['visible'], name1, name2, mode)
return generate_chat_output(shared.history['visible'], name1, name2, shared.character)
def clear_html():
return chat_html_wrapper([], "", "")
return generate_chat_html([], "", "", shared.character)
def clear_chat_log(name1, name2, greeting, mode):
def clear_chat_log(name1, name2, greeting):
shared.history['visible'] = []
shared.history['internal'] = []
@@ -235,12 +236,12 @@ def clear_chat_log(name1, name2, greeting, mode):
shared.history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', greeting]]
shared.history['visible'] += [['', apply_extensions(greeting, "output")]]
return chat_html_wrapper(shared.history['visible'], name1, name2, mode)
return generate_chat_output(shared.history['visible'], name1, name2, shared.character)
def redraw_html(name1, name2, mode):
return chat_html_wrapper(shared.history['visible'], name1, name2, mode)
def redraw_html(name1, name2):
return generate_chat_html(shared.history['visible'], name1, name2, shared.character)
def tokenize_dialogue(dialogue, name1, name2, mode):
def tokenize_dialogue(dialogue, name1, name2):
history = []
dialogue = re.sub('<START>', '', dialogue)
@@ -325,35 +326,15 @@ def build_pygmalion_style_context(data):
context = f"{context.strip()}\n<START>\n"
return context
def generate_pfp_cache(character):
cache_folder = Path("cache")
if not cache_folder.exists():
cache_folder.mkdir()
for path in [Path(f"characters/{character}.{extension}") for extension in ['png', 'jpg', 'jpeg']]:
if path.exists():
img = make_thumbnail(Image.open(path))
img.save(Path('cache/pfp_character.png'), format='PNG')
return img
return None
def load_character(character, name1, name2, mode):
def load_character(character, name1, name2):
shared.character = character
shared.history['internal'] = []
shared.history['visible'] = []
context = greeting = end_of_turn = ""
greeting_field = 'greeting'
picture = None
# Deleting the profile picture cache, if any
if Path("cache/pfp_character.png").exists():
Path("cache/pfp_character.png").unlink()
greeting = ""
if character != 'None':
folder = 'characters' if not mode == 'instruct' else 'characters/instruction-following'
picture = generate_pfp_cache(character)
for extension in ["yml", "yaml", "json"]:
filepath = Path(f'{folder}/{character}.{extension}')
filepath = Path(f'characters/{character}.{extension}')
if filepath.exists():
break
file_contents = open(filepath, 'r', encoding='utf-8').read()
@@ -369,21 +350,19 @@ def load_character(character, name1, name2, mode):
if 'context' in data:
context = f"{data['context'].strip()}\n\n"
elif "char_persona" in data:
greeting_field = 'greeting'
else:
context = build_pygmalion_style_context(data)
greeting_field = 'char_greeting'
if 'example_dialogue' in data:
if 'example_dialogue' in data and data['example_dialogue'] != '':
context += f"{data['example_dialogue'].strip()}\n"
if greeting_field in data:
if greeting_field in data and len(data[greeting_field].strip()) > 0:
greeting = data[greeting_field]
if 'end_of_turn' in data:
end_of_turn = data['end_of_turn']
else:
context = shared.settings['context']
name2 = shared.settings['name2']
greeting = shared.settings['greeting']
end_of_turn = shared.settings['end_of_turn']
if Path(f'logs/{shared.character}_persistent.json').exists():
load_history(open(Path(f'logs/{shared.character}_persistent.json'), 'rb').read(), name1, name2)
@@ -391,10 +370,13 @@ def load_character(character, name1, name2, mode):
shared.history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', greeting]]
shared.history['visible'] += [['', apply_extensions(greeting, "output")]]
return name1, name2, picture, greeting, context, end_of_turn, chat_html_wrapper(shared.history['visible'], name1, name2, mode, reset_cache=True)
if shared.args.cai_chat:
return name1, name2, greeting, context, generate_chat_html(shared.history['visible'], name1, name2, shared.character)
else:
return name1, name2, greeting, context, shared.history['visible']
def load_default_history(name1, name2):
load_character("None", name1, name2, "chat")
load_character("None", name1, name2)
def upload_character(json_file, img, tavern=False):
json_file = json_file if type(json_file) == str else json_file.decode('utf-8')
@@ -422,17 +404,7 @@ def upload_tavern_character(img, name1, name2):
_json = {"char_name": _json['name'], "char_persona": _json['description'], "char_greeting": _json["first_mes"], "example_dialogue": _json['mes_example'], "world_scenario": _json['scenario']}
return upload_character(json.dumps(_json), img, tavern=True)
def upload_your_profile_picture(img, name1, name2, mode):
cache_folder = Path("cache")
if not cache_folder.exists():
cache_folder.mkdir()
if img == None:
if Path("cache/pfp_me.png").exists():
Path("cache/pfp_me.png").unlink()
else:
img = make_thumbnail(img)
img.save(Path('cache/pfp_me.png'))
print('Profile picture saved to "cache/pfp_me.png"')
return chat_html_wrapper(shared.history['visible'], name1, name2, mode, reset_cache=True)
def upload_your_profile_picture(img):
img = Image.open(io.BytesIO(img))
img.save(Path('img_me.png'))
print('Profile picture saved to "img_me.png"')

View File

@@ -6,11 +6,10 @@ This is a library for formatting text outputs as nice HTML.
import os
import re
import time
from pathlib import Path
import markdown
from PIL import Image, ImageOps
from PIL import Image
# This is to store the paths to the thumbnails of the profile pictures
image_cache = {}
@@ -21,8 +20,6 @@ with open(Path(__file__).resolve().parent / '../css/html_4chan_style.css', 'r')
_4chan_css = css_f.read()
with open(Path(__file__).resolve().parent / '../css/html_cai_style.css', 'r') as f:
cai_css = f.read()
with open(Path(__file__).resolve().parent / '../css/html_instruct_style.css', 'r') as f:
instruct_css = f.read()
def fix_newlines(string):
string = string.replace('\n', '\n\n')
@@ -98,13 +95,6 @@ def generate_4chan_html(f):
return output
def make_thumbnail(image):
image = image.resize((350, round(image.size[1]/image.size[0]*350)), Image.Resampling.LANCZOS)
if image.size[1] > 470:
image = ImageOps.fit(image, (350, 470), Image.ANTIALIAS)
return image
def get_image_cache(path):
cache_folder = Path("cache")
if not cache_folder.exists():
@@ -112,52 +102,26 @@ def get_image_cache(path):
mtime = os.stat(path).st_mtime
if (path in image_cache and mtime != image_cache[path][0]) or (path not in image_cache):
img = make_thumbnail(Image.open(path))
img = Image.open(path)
img.thumbnail((200, 200))
output_file = Path(f'cache/{path.name}_cache.png')
img.convert('RGB').save(output_file, format='PNG')
image_cache[path] = [mtime, output_file.as_posix()]
return image_cache[path][1]
def generate_instruct_html(history):
output = f'<style>{instruct_css}</style><div class="chat" id="chat">'
for i,_row in enumerate(history[::-1]):
row = [convert_to_markdown(entry) for entry in _row]
def load_html_image(paths):
for str_path in paths:
path = Path(str_path)
if path.exists():
return f'<img src="file/{get_image_cache(path)}">'
return ''
output += f"""
<div class="assistant-message">
<div class="text">
<div class="message-body">
{row[1]}
</div>
</div>
</div>
"""
if len(row[0]) == 0: # don't display empty user messages
continue
output += f"""
<div class="user-message">
<div class="text">
<div class="message-body">
{row[0]}
</div>
</div>
</div>
"""
output += "</div>"
return output
def generate_cai_chat_html(history, name1, name2, reset_cache=False):
def generate_chat_html(history, name1, name2, character):
output = f'<style>{cai_css}</style><div class="chat" id="chat">'
# The time.time() is to prevent the brower from caching the image
suffix = f"?{time.time()}" if reset_cache else f"?{name2}"
img_bot = f'<img src="file/cache/pfp_character.png{suffix}">' if Path("cache/pfp_character.png").exists() else ''
img_me = f'<img src="file/cache/pfp_me.png{suffix}">' if Path("cache/pfp_me.png").exists() else ''
img_bot = load_html_image([f"characters/{character}.{ext}" for ext in ['png', 'jpg', 'jpeg']] + ["img_bot.png","img_bot.jpg","img_bot.jpeg"])
img_me = load_html_image(["img_me.png", "img_me.jpg", "img_me.jpeg"])
for i,_row in enumerate(history[::-1]):
row = [convert_to_markdown(entry) for entry in _row]
@@ -199,16 +163,3 @@ def generate_cai_chat_html(history, name1, name2, reset_cache=False):
output += "</div>"
return output
def generate_chat_html(history, name1, name2):
return generate_cai_chat_html(history, name1, name2)
def chat_html_wrapper(history, name1, name2, mode, reset_cache=False):
if mode == "cai-chat":
return generate_cai_chat_html(history, name1, name2, reset_cache)
elif mode == "chat":
return generate_chat_html(history, name1, name2)
elif mode == "instruct":
return generate_instruct_html(history)
else:
return ''

View File

@@ -42,7 +42,7 @@ def load_model(model_name):
t0 = time.time()
shared.is_RWKV = 'rwkv-' in model_name.lower()
shared.is_llamacpp = len(list(Path(f'{shared.args.model_dir}/{model_name}').glob('ggml*.bin'))) > 0
shared.is_llamacpp = len(list(Path(f'models/{model_name}').glob('ggml*.bin'))) > 0
# Default settings
if not any([shared.args.cpu, shared.args.load_in_8bit, shared.args.wbits, shared.args.auto_devices, shared.args.disk, shared.args.gpu_memory is not None, shared.args.cpu_memory is not None, shared.args.deepspeed, shared.args.flexgen, shared.is_RWKV, shared.is_llamacpp]):
@@ -105,7 +105,7 @@ def load_model(model_name):
elif shared.is_llamacpp:
from modules.llamacpp_model import LlamaCppModel
model_file = list(Path(f'{shared.args.model_dir}/{model_name}').glob('ggml*.bin'))[0]
model_file = list(Path(f'models/{model_name}').glob('ggml*.bin'))[0]
print(f"llama.cpp weights detected: {model_file}\n")
model, tokenizer = LlamaCppModel.from_pretrained(model_file)

View File

@@ -33,7 +33,6 @@ settings = {
'name2': 'Assistant',
'context': 'This is a conversation with your Assistant. The Assistant is very helpful and is eager to chat with you and answer your questions.',
'greeting': 'Hello there!',
'end_of_turn': '',
'stop_at_newline': False,
'chat_prompt_size': 2048,
'chat_prompt_size_min': 0,
@@ -45,7 +44,6 @@ settings = {
'chat_default_extensions': ["gallery"],
'presets': {
'default': 'NovelAI-Sphinx Moth',
'.*(alpaca|llama)': "LLaMA-Precise",
'.*pygmalion': 'NovelAI-Storywriter',
'.*RWKV': 'Naive',
},
@@ -75,8 +73,8 @@ parser = argparse.ArgumentParser(formatter_class=lambda prog: argparse.HelpForma
# Basic settings
parser.add_argument('--notebook', action='store_true', help='Launch the web UI in notebook mode, where the output is written to the same text box as the input.')
parser.add_argument('--chat', action='store_true', help='Launch the web UI in chat mode with a style similar to the Character.AI website.')
parser.add_argument('--cai-chat', action='store_true', help='DEPRECATED: use --chat instead.')
parser.add_argument('--chat', action='store_true', help='Launch the web UI in chat mode.')
parser.add_argument('--cai-chat', action='store_true', help='Launch the web UI in chat mode with a style similar to Character.AI\'s. If the file img_bot.png or img_bot.jpg exists in the same folder as server.py, this image will be used as the bot\'s profile picture. Similarly, img_me.png or img_me.jpg will be used as your profile picture.')
parser.add_argument('--model', type=str, help='Name of the model to load by default.')
parser.add_argument('--lora', type=str, help='Name of the LoRA to apply to the model by default.')
parser.add_argument("--model-dir", type=str, default='models/', help="Path to directory with all the models")
@@ -133,17 +131,12 @@ parser.add_argument("--gradio-auth-path", type=str, help='Set the gradio authent
args = parser.parse_args()
# Deprecation warnings for parameters that have been renamed
# Provisional, this will be deleted later
deprecated_dict = {'gptq_bits': ['wbits', 0], 'gptq_model_type': ['model_type', None], 'gptq_pre_layer': ['prelayer', 0]}
for k in deprecated_dict:
if eval(f"args.{k}") != deprecated_dict[k][1]:
print(f"Warning: --{k} is deprecated and will be removed. Use --{deprecated_dict[k][0]} instead.")
exec(f"args.{deprecated_dict[k][0]} = args.{k}")
# Deprecation warnings for parameters that have been removed
if args.cai_chat:
print("Warning: --cai-chat is deprecated. Use --chat instead.")
args.chat = True
def is_chat():
return args.chat
return any((args.chat, args.cai_chat))

View File

@@ -102,11 +102,10 @@ def set_manual_seed(seed):
def stop_everything_event():
shared.stop_everything = True
def generate_reply(question, generate_state, eos_token=None, stopping_strings=[]):
def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, eos_token=None, stopping_strings=[]):
clear_torch_cache()
set_manual_seed(generate_state['seed'])
set_manual_seed(seed)
shared.stop_everything = False
generate_params = {}
t0 = time.time()
original_question = question
@@ -118,12 +117,9 @@ def generate_reply(question, generate_state, eos_token=None, stopping_strings=[]
# These models are not part of Hugging Face, so we handle them
# separately and terminate the function call earlier
if any((shared.is_RWKV, shared.is_llamacpp)):
for k in ['temperature', 'top_p', 'top_k', 'repetition_penalty']:
generate_params[k] = generate_state[k]
generate_params["token_count"] = generate_state["max_new_tokens"]
try:
if shared.args.no_stream:
reply = shared.model.generate(context=question, **generate_params)
reply = shared.model.generate(context=question, token_count=max_new_tokens, temperature=temperature, top_p=top_p, top_k=top_k, repetition_penalty=repetition_penalty)
output = original_question+reply
if not shared.is_chat():
reply = original_question + apply_extensions(reply, "output")
@@ -134,7 +130,7 @@ def generate_reply(question, generate_state, eos_token=None, stopping_strings=[]
# RWKV has proper streaming, which is very nice.
# No need to generate 8 tokens at a time.
for reply in shared.model.generate_with_streaming(context=question, **generate_params):
for reply in shared.model.generate_with_streaming(context=question, token_count=max_new_tokens, temperature=temperature, top_p=top_p, top_k=top_k, repetition_penalty=repetition_penalty):
output = original_question+reply
if not shared.is_chat():
reply = original_question + apply_extensions(reply, "output")
@@ -149,7 +145,7 @@ def generate_reply(question, generate_state, eos_token=None, stopping_strings=[]
print(f"Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens})")
return
input_ids = encode(question, generate_state['max_new_tokens'])
input_ids = encode(question, max_new_tokens)
original_input_ids = input_ids
output = input_ids[0]
@@ -162,21 +158,33 @@ def generate_reply(question, generate_state, eos_token=None, stopping_strings=[]
t = [encode(string, 0, add_special_tokens=False) for string in stopping_strings]
stopping_criteria_list.append(_SentinelTokenStoppingCriteria(sentinel_token_ids=t, starting_idx=len(input_ids[0])))
generate_params["max_new_tokens"] = generate_state['max_new_tokens']
generate_params = {}
if not shared.args.flexgen:
for k in ["do_sample", "temperature", "top_p", "typical_p", "repetition_penalty", "encoder_repetition_penalty", "top_k", "min_length", "no_repeat_ngram_size", "num_beams", "penalty_alpha", "length_penalty", "early_stopping"]:
generate_params[k] = generate_state[k]
generate_params["eos_token_id"] = eos_token_ids
generate_params["stopping_criteria"] = stopping_criteria_list
if shared.args.no_stream:
generate_params["min_length"] = 0
generate_params.update({
"max_new_tokens": max_new_tokens,
"eos_token_id": eos_token_ids,
"stopping_criteria": stopping_criteria_list,
"do_sample": do_sample,
"temperature": temperature,
"top_p": top_p,
"typical_p": typical_p,
"repetition_penalty": repetition_penalty,
"encoder_repetition_penalty": encoder_repetition_penalty,
"top_k": top_k,
"min_length": min_length if shared.args.no_stream else 0,
"no_repeat_ngram_size": no_repeat_ngram_size,
"num_beams": num_beams,
"penalty_alpha": penalty_alpha,
"length_penalty": length_penalty,
"early_stopping": early_stopping,
})
else:
for k in ["do_sample", "temperature"]:
generate_params[k] = generate_state[k]
generate_params["stop"] = generate_state["eos_token_ids"][-1]
if not shared.args.no_stream:
generate_params["max_new_tokens"] = 8
generate_params.update({
"max_new_tokens": max_new_tokens if shared.args.no_stream else 8,
"do_sample": do_sample,
"temperature": temperature,
"stop": eos_token_ids[-1],
})
if shared.args.no_cache:
generate_params.update({"use_cache": False})
if shared.args.deepspeed:
@@ -236,7 +244,7 @@ def generate_reply(question, generate_state, eos_token=None, stopping_strings=[]
# Stream the output naively for FlexGen since it doesn't support 'stopping_criteria'
else:
for i in range(generate_state['max_new_tokens']//8+1):
for i in range(max_new_tokens//8+1):
clear_torch_cache()
with torch.no_grad():
output = shared.model.generate(**generate_params)[0]

View File

@@ -1,6 +0,0 @@
do_sample=True
top_p=0.1
top_k=40
temperature=0.7
repetition_penalty=1.18
typical_p=1.0

View File

@@ -1 +0,0 @@
<|prompter|>Write a story about future of AI development<|endoftext|><|assistant|>

View File

@@ -13,4 +13,4 @@ safetensors==0.3.0
sentencepiece
pyyaml
tqdm
git+https://github.com/huggingface/transformers@9eae4aa57650c1dbe1becd4e0979f6ad1e572ac0
git+https://github.com/huggingface/transformers

121
server.py
View File

@@ -1,7 +1,3 @@
import os
os.environ['GRADIO_ANALYTICS_ENABLED'] = 'False'
import io
import json
import re
@@ -12,11 +8,10 @@ from datetime import datetime
from pathlib import Path
import gradio as gr
from PIL import Image
import modules.extensions as extensions_module
from modules import chat, shared, training, ui, api
from modules.html_generator import chat_html_wrapper
from modules import chat, shared, training, ui
from modules.html_generator import generate_chat_html
from modules.LoRA import add_lora_to_model
from modules.models import load_model, load_soft_prompt
from modules.text_generation import (clear_torch_cache, generate_reply,
@@ -52,10 +47,6 @@ def get_available_prompts():
def get_available_characters():
paths = (x for x in Path('characters').iterdir() if x.suffix in ('.json', '.yaml', '.yml'))
return ['None'] + sorted(set((k.stem for k in paths if k.stem != "instruction-following")), key=str.lower)
def get_available_instruction_templates():
paths = (x for x in Path('characters/instruction-following').iterdir() if x.suffix in ('.json', '.yaml', '.yml'))
return ['None'] + sorted(set((k.stem for k in paths)), key=str.lower)
def get_available_extensions():
@@ -85,7 +76,7 @@ def load_lora_wrapper(selected_lora):
add_lora_to_model(selected_lora)
return selected_lora
def load_preset_values(preset_menu, state, return_dict=False):
def load_preset_values(preset_menu, return_dict=False):
generate_params = {
'do_sample': True,
'temperature': 1,
@@ -107,13 +98,13 @@ def load_preset_values(preset_menu, state, return_dict=False):
i = i.rstrip(',').strip().split('=')
if len(i) == 2 and i[0].strip() != 'tokens':
generate_params[i[0].strip()] = eval(i[1].strip())
generate_params['temperature'] = min(1.99, generate_params['temperature'])
if return_dict:
return generate_params
else:
state.update(generate_params)
return state, *[generate_params[k] for k in ['do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']]
return generate_params['do_sample'], generate_params['temperature'], generate_params['top_p'], generate_params['typical_p'], generate_params['repetition_penalty'], generate_params['encoder_repetition_penalty'], generate_params['top_k'], generate_params['min_length'], generate_params['no_repeat_ngram_size'], generate_params['num_beams'], generate_params['penalty_alpha'], generate_params['length_penalty'], generate_params['early_stopping']
def upload_soft_prompt(file):
with zipfile.ZipFile(io.BytesIO(file)) as zf:
@@ -139,7 +130,7 @@ def create_model_and_preset_menus():
ui.create_refresh_button(shared.gradio['preset_menu'], lambda : None, lambda : {'choices': get_available_presets()}, 'refresh-button')
def save_prompt(text):
fname = f"{datetime.now().strftime('%Y-%m-%d-%H%M%S')}.txt"
fname = f"{datetime.now().strftime('%Y-%m-%d-%H:%M:%S')}.txt"
with open(Path(f'prompts/{fname}'), 'w', encoding='utf-8') as f:
f.write(text)
return f"Saved to prompts/{fname}"
@@ -153,7 +144,7 @@ def load_prompt(fname):
if text[-1] == '\n':
text = text[:-1]
return text
def create_prompt_menus():
with gr.Row():
with gr.Column():
@@ -170,10 +161,7 @@ def create_prompt_menus():
shared.gradio['save_prompt'].click(save_prompt, [shared.gradio['textbox']], [shared.gradio['status']], show_progress=False)
def create_settings_menus(default_preset):
generate_params = load_preset_values(default_preset if not shared.args.flexgen else 'Naive', {}, return_dict=True)
for k in ['max_new_tokens', 'seed', 'stop_at_newline', 'chat_prompt_size', 'chat_generation_attempts']:
generate_params[k] = shared.settings[k]
shared.gradio['generate_state'] = gr.State(generate_params)
generate_params = load_preset_values(default_preset if not shared.args.flexgen else 'Naive', return_dict=True)
with gr.Row():
with gr.Column():
@@ -224,16 +212,17 @@ def create_settings_menus(default_preset):
with gr.Row():
shared.gradio['upload_softprompt'] = gr.File(type='binary', file_types=['.zip'])
shared.gradio['model_menu'].change(load_model_wrapper, shared.gradio['model_menu'], shared.gradio['model_menu'], show_progress=True)
shared.gradio['preset_menu'].change(load_preset_values, [shared.gradio[k] for k in ['preset_menu', 'generate_state']], [shared.gradio[k] for k in ['generate_state', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']])
shared.gradio['lora_menu'].change(load_lora_wrapper, shared.gradio['lora_menu'], shared.gradio['lora_menu'], show_progress=True)
shared.gradio['softprompts_menu'].change(load_soft_prompt, shared.gradio['softprompts_menu'], shared.gradio['softprompts_menu'], show_progress=True)
shared.gradio['upload_softprompt'].upload(upload_soft_prompt, shared.gradio['upload_softprompt'], shared.gradio['softprompts_menu'])
shared.gradio['model_menu'].change(load_model_wrapper, [shared.gradio['model_menu']], [shared.gradio['model_menu']], show_progress=True)
shared.gradio['preset_menu'].change(load_preset_values, [shared.gradio['preset_menu']], [shared.gradio[k] for k in ['do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']])
shared.gradio['lora_menu'].change(load_lora_wrapper, [shared.gradio['lora_menu']], [shared.gradio['lora_menu']], show_progress=True)
shared.gradio['softprompts_menu'].change(load_soft_prompt, [shared.gradio['softprompts_menu']], [shared.gradio['softprompts_menu']], show_progress=True)
shared.gradio['upload_softprompt'].upload(upload_soft_prompt, [shared.gradio['upload_softprompt']], [shared.gradio['softprompts_menu']])
def set_interface_arguments(interface_mode, extensions, bool_active):
modes = ["default", "notebook", "chat", "cai_chat"]
cmd_list = vars(shared.args)
bool_list = [k for k in cmd_list if type(cmd_list[k]) is bool and k not in modes]
#int_list = [k for k in cmd_list if type(k) is int]
shared.args.extensions = extensions
for k in modes[1:]:
@@ -306,7 +295,10 @@ def create_interface():
if shared.is_chat():
shared.gradio['Chat input'] = gr.State()
with gr.Tab("Text generation", elem_id="main"):
shared.gradio['display'] = gr.HTML(value=chat_html_wrapper(shared.history['visible'], shared.settings['name1'], shared.settings['name2'], 'cai-chat'))
if shared.args.cai_chat:
shared.gradio['display'] = gr.HTML(value=generate_chat_html(shared.history['visible'], shared.settings['name1'], shared.settings['name2'], shared.character))
else:
shared.gradio['display'] = gr.Chatbot(value=shared.history['visible'], elem_id="gradio-chatbot")
shared.gradio['textbox'] = gr.Textbox(label='Input')
with gr.Row():
shared.gradio['Generate'] = gr.Button('Generate')
@@ -323,20 +315,11 @@ def create_interface():
shared.gradio['Clear history-confirm'] = gr.Button('Confirm', variant="stop", visible=False)
shared.gradio['Clear history-cancel'] = gr.Button('Cancel', visible=False)
shared.gradio["Chat mode"] = gr.Radio(choices=["cai-chat", "chat", "instruct"], value="cai-chat", label="Mode")
shared.gradio["Instruction templates"] = gr.Dropdown(choices=get_available_instruction_templates(), label="Instruction template", value="None", visible=False)
with gr.Tab("Character", elem_id="chat-settings"):
with gr.Row():
with gr.Column(scale=8):
shared.gradio['name1'] = gr.Textbox(value=shared.settings['name1'], lines=1, label='Your name')
shared.gradio['name2'] = gr.Textbox(value=shared.settings['name2'], lines=1, label='Character\'s name')
shared.gradio['greeting'] = gr.Textbox(value=shared.settings['greeting'], lines=4, label='Greeting')
shared.gradio['context'] = gr.Textbox(value=shared.settings['context'], lines=4, label='Context')
shared.gradio['end_of_turn'] = gr.Textbox(value=shared.settings["end_of_turn"], lines=1, label='End of turn string')
with gr.Column(scale=1):
shared.gradio['character_picture'] = gr.Image(label='Character picture', type="pil")
shared.gradio['your_picture'] = gr.Image(label='Your picture', type="pil", value=Image.open(Path("cache/pfp_me.png")) if Path("cache/pfp_me.png").exists() else None)
shared.gradio['name1'] = gr.Textbox(value=shared.settings['name1'], lines=1, label='Your name')
shared.gradio['name2'] = gr.Textbox(value=shared.settings['name2'], lines=1, label='Character\'s name')
shared.gradio['greeting'] = gr.Textbox(value=shared.settings['greeting'], lines=2, label='Greeting')
shared.gradio['context'] = gr.Textbox(value=shared.settings['context'], lines=8, label='Context')
with gr.Row():
shared.gradio['character_menu'] = gr.Dropdown(choices=available_characters, value='None', label='Character', elem_id='character-menu')
ui.create_refresh_button(shared.gradio['character_menu'], lambda : None, lambda : {'choices': get_available_characters()}, 'refresh-button')
@@ -364,6 +347,8 @@ def create_interface():
gr.Markdown("# TavernAI PNG format")
shared.gradio['upload_img_tavern'] = gr.File(type='binary', file_types=['image'])
with gr.Tab('Upload your profile picture'):
shared.gradio['upload_img_me'] = gr.File(type='binary', file_types=['image'])
with gr.Tab("Parameters", elem_id="parameters"):
with gr.Box():
@@ -374,35 +359,35 @@ def create_interface():
shared.gradio['chat_prompt_size_slider'] = gr.Slider(minimum=shared.settings['chat_prompt_size_min'], maximum=shared.settings['chat_prompt_size_max'], step=1, label='Maximum prompt size in tokens', value=shared.settings['chat_prompt_size'])
with gr.Column():
shared.gradio['chat_generation_attempts'] = gr.Slider(minimum=shared.settings['chat_generation_attempts_min'], maximum=shared.settings['chat_generation_attempts_max'], value=shared.settings['chat_generation_attempts'], step=1, label='Generation attempts (for longer replies)')
shared.gradio['stop_at_newline'] = gr.Checkbox(value=shared.settings['stop_at_newline'], label='Stop generating at new line character?')
shared.gradio['check'] = gr.Checkbox(value=shared.settings['stop_at_newline'], label='Stop generating at new line character?')
create_settings_menus(default_preset)
shared.input_params = [shared.gradio[k] for k in ['Chat input', 'generate_state', 'name1', 'name2', 'context', 'Chat mode', 'end_of_turn']]
function_call = 'chat.cai_chatbot_wrapper' if shared.args.cai_chat else 'chat.chatbot_wrapper'
shared.input_params = [shared.gradio[k] for k in ['Chat input', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'seed', 'name1', 'name2', 'context', 'check', 'chat_prompt_size_slider', 'chat_generation_attempts']]
def set_chat_input(textbox):
return textbox, ""
gen_events.append(shared.gradio['Generate'].click(set_chat_input, shared.gradio['textbox'], [shared.gradio['Chat input'], shared.gradio['textbox']], show_progress=False))
gen_events.append(shared.gradio['Generate'].click(chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream))
gen_events.append(shared.gradio['Generate'].click(eval(function_call), shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream))
gen_events.append(shared.gradio['textbox'].submit(set_chat_input, shared.gradio['textbox'], [shared.gradio['Chat input'], shared.gradio['textbox']], show_progress=False))
gen_events.append(shared.gradio['textbox'].submit(chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream))
gen_events.append(shared.gradio['textbox'].submit(eval(function_call), shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream))
gen_events.append(shared.gradio['Regenerate'].click(chat.regenerate_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream))
gen_events.append(shared.gradio['Impersonate'].click(chat.impersonate_wrapper, shared.input_params, shared.gradio['textbox'], show_progress=shared.args.no_stream))
shared.gradio['Stop'].click(stop_everything_event, [], [], queue=False, cancels=gen_events if shared.args.no_stream else None)
shared.gradio['Copy last reply'].click(chat.send_last_reply_to_input, [], shared.gradio['textbox'], show_progress=shared.args.no_stream)
shared.gradio['Replace last reply'].click(chat.replace_last_reply, [shared.gradio[k] for k in ['textbox', 'name1', 'name2', 'Chat mode']], shared.gradio['display'], show_progress=shared.args.no_stream)
shared.gradio['Replace last reply'].click(chat.replace_last_reply, [shared.gradio['textbox'], shared.gradio['name1'], shared.gradio['name2']], shared.gradio['display'], show_progress=shared.args.no_stream)
# Clear history with confirmation
clear_arr = [shared.gradio[k] for k in ['Clear history-confirm', 'Clear history', 'Clear history-cancel']]
shared.gradio['Clear history'].click(lambda :[gr.update(visible=True), gr.update(visible=False), gr.update(visible=True)], None, clear_arr)
shared.gradio['Clear history-confirm'].click(lambda :[gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, clear_arr)
shared.gradio['Clear history-confirm'].click(chat.clear_chat_log, [shared.gradio[k] for k in ['name1', 'name2', 'greeting', 'Chat mode']], shared.gradio['display'])
shared.gradio['Clear history-confirm'].click(chat.clear_chat_log, [shared.gradio['name1'], shared.gradio['name2'], shared.gradio['greeting']], shared.gradio['display'])
shared.gradio['Clear history-cancel'].click(lambda :[gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, clear_arr)
shared.gradio['Chat mode'].change(lambda x : gr.update(visible= x=='instruct'), shared.gradio['Chat mode'], shared.gradio['Instruction templates'])
shared.gradio['Remove last'].click(chat.remove_last_message, [shared.gradio[k] for k in ['name1', 'name2', 'Chat mode']], [shared.gradio['display'], shared.gradio['textbox']], show_progress=False)
shared.gradio['Remove last'].click(chat.remove_last_message, [shared.gradio['name1'], shared.gradio['name2']], [shared.gradio['display'], shared.gradio['textbox']], show_progress=False)
shared.gradio['download_button'].click(chat.save_history, inputs=[], outputs=[shared.gradio['download']])
shared.gradio['Upload character'].click(chat.upload_character, [shared.gradio['upload_json'], shared.gradio['upload_img_bot']], [shared.gradio['character_menu']])
@@ -414,21 +399,20 @@ def create_interface():
shared.gradio['textbox'].submit(lambda x: '', shared.gradio['textbox'], shared.gradio['textbox'], show_progress=False)
shared.gradio['textbox'].submit(lambda : chat.save_history(timestamp=False), [], [], show_progress=False)
shared.gradio['character_menu'].change(chat.load_character, [shared.gradio[k] for k in ['character_menu', 'name1', 'name2', 'Chat mode']], [shared.gradio[k] for k in ['name1', 'name2', 'character_picture', 'greeting', 'context', 'end_of_turn', 'display']])
shared.gradio['Instruction templates'].change(lambda character, name1, name2, mode: chat.load_character(character, name1, name2, mode), [shared.gradio[k] for k in ['Instruction templates', 'name1', 'name2', 'Chat mode']], [shared.gradio[k] for k in ['name1', 'name2', 'character_picture', 'greeting', 'context', 'end_of_turn', 'display']])
shared.gradio['upload_chat_history'].upload(chat.load_history, [shared.gradio[k] for k in ['upload_chat_history', 'name1', 'name2']], [])
shared.gradio['character_menu'].change(chat.load_character, [shared.gradio[k] for k in ['character_menu', 'name1', 'name2']], [shared.gradio[k] for k in ['name1', 'name2', 'greeting', 'context', 'display']])
shared.gradio['upload_chat_history'].upload(chat.load_history, [shared.gradio['upload_chat_history'], shared.gradio['name1'], shared.gradio['name2']], [])
shared.gradio['upload_img_tavern'].upload(chat.upload_tavern_character, [shared.gradio['upload_img_tavern'], shared.gradio['name1'], shared.gradio['name2']], [shared.gradio['character_menu']])
shared.gradio['your_picture'].change(chat.upload_your_profile_picture, [shared.gradio[k] for k in ['your_picture', 'name1', 'name2', 'Chat mode']], shared.gradio['display'])
shared.gradio['upload_img_me'].upload(chat.upload_your_profile_picture, [shared.gradio['upload_img_me']], [])
reload_inputs = [shared.gradio[k] for k in ['name1', 'name2', 'Chat mode']]
shared.gradio['upload_chat_history'].upload(chat.redraw_html, reload_inputs, [shared.gradio['display']])
shared.gradio['Stop'].click(chat.redraw_html, reload_inputs, [shared.gradio['display']])
shared.gradio['Instruction templates'].change(chat.redraw_html, reload_inputs, [shared.gradio['display']])
shared.gradio['Chat mode'].change(chat.redraw_html, reload_inputs, [shared.gradio['display']])
reload_func = chat.redraw_html if shared.args.cai_chat else lambda : shared.history['visible']
reload_inputs = [shared.gradio['name1'], shared.gradio['name2']] if shared.args.cai_chat else []
shared.gradio['upload_chat_history'].upload(reload_func, reload_inputs, [shared.gradio['display']])
shared.gradio['upload_img_me'].upload(reload_func, reload_inputs, [shared.gradio['display']])
shared.gradio['Stop'].click(reload_func, reload_inputs, [shared.gradio['display']])
shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js+ui.chat_js}}}")
shared.gradio['interface'].load(lambda : chat.load_default_history(shared.settings['name1'], shared.settings['name2']), None, None)
shared.gradio['interface'].load(chat.redraw_html, reload_inputs, [shared.gradio['display']], show_progress=True)
shared.gradio['interface'].load(reload_func, reload_inputs, [shared.gradio['display']], show_progress=True)
elif shared.args.notebook:
with gr.Tab("Text generation", elem_id="main"):
@@ -458,9 +442,9 @@ def create_interface():
with gr.Tab("Parameters", elem_id="parameters"):
create_settings_menus(default_preset)
shared.input_params = [shared.gradio[k] for k in ['textbox', 'generate_state']]
shared.input_params = [shared.gradio[k] for k in ['textbox', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'seed']]
output_params = [shared.gradio[k] for k in ['textbox', 'markdown', 'html']]
gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream))
gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream, api_name='textgen'))
gen_events.append(shared.gradio['textbox'].submit(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream))
shared.gradio['Stop'].click(stop_everything_event, [], [], queue=False, cancels=gen_events if shared.args.no_stream else None)
shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js}}}")
@@ -491,9 +475,9 @@ def create_interface():
with gr.Tab("Parameters", elem_id="parameters"):
create_settings_menus(default_preset)
shared.input_params = [shared.gradio[k] for k in ['textbox', 'generate_state']]
shared.input_params = [shared.gradio[k] for k in ['textbox', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'seed']]
output_params = [shared.gradio[k] for k in ['output_textbox', 'markdown', 'html']]
gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream))
gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream, api_name='textgen'))
gen_events.append(shared.gradio['textbox'].submit(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream))
gen_events.append(shared.gradio['Continue'].click(generate_reply, [shared.gradio['output_textbox']] + shared.input_params[1:], output_params, show_progress=shared.args.no_stream))
shared.gradio['Stop'].click(stop_everything_event, [], [], queue=False, cancels=gen_events if shared.args.no_stream else None)
@@ -526,21 +510,6 @@ def create_interface():
if shared.args.extensions is not None:
extensions_module.create_extensions_block()
def change_dict_value(d, key, value):
d[key] = value
return d
for k in ['do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'max_new_tokens', 'seed', 'stop_at_newline', 'chat_prompt_size_slider', 'chat_generation_attempts']:
if k not in shared.gradio:
continue
if type(shared.gradio[k]) in [gr.Checkbox, gr.Number]:
shared.gradio[k].change(lambda state, value, copy=k: change_dict_value(state, copy, value), inputs=[shared.gradio['generate_state'], shared.gradio[k]], outputs=shared.gradio['generate_state'])
else:
shared.gradio[k].release(lambda state, value, copy=k: change_dict_value(state, copy, value), inputs=[shared.gradio['generate_state'], shared.gradio[k]], outputs=shared.gradio['generate_state'])
if not shared.is_chat():
api.create_apis()
# Authentication
auth = None
if shared.args.gradio_auth_path is not None: