diff --git a/README.md b/README.md
index 373f83f..c4dd01d 100644
--- a/README.md
+++ b/README.md
@@ -15,6 +15,7 @@ Its goal is to become the [AUTOMATIC1111/stable-diffusion-webui](https://github.
* Dropdown menu for switching between models
* Notebook mode that resembles OpenAI's playground
* Chat mode for conversation and role playing
+* Instruct mode compatible with Alpaca and Open Assistant formats **\*NEW!\***
* Nice HTML output for GPT-4chan
* Markdown output for [GALACTICA](https://github.com/paperswithcode/galai), including LaTeX rendering
* [Custom chat characters](https://github.com/oobabooga/text-generation-webui/wiki/Custom-chat-characters)
@@ -26,11 +27,11 @@ Its goal is to become the [AUTOMATIC1111/stable-diffusion-webui](https://github.
* CPU mode
* [FlexGen](https://github.com/oobabooga/text-generation-webui/wiki/FlexGen)
* [DeepSpeed ZeRO-3](https://github.com/oobabooga/text-generation-webui/wiki/DeepSpeed)
-* API [with](https://github.com/oobabooga/text-generation-webui/blob/main/api-example-streaming.py) streaming and [without](https://github.com/oobabooga/text-generation-webui/blob/main/api-example.py) streaming
+* API [with](https://github.com/oobabooga/text-generation-webui/blob/main/api-example-stream.py) streaming and [without](https://github.com/oobabooga/text-generation-webui/blob/main/api-example.py) streaming
* [LLaMA model, including 4-bit GPTQ](https://github.com/oobabooga/text-generation-webui/wiki/LLaMA-model)
* [llama.cpp](https://github.com/oobabooga/text-generation-webui/wiki/llama.cpp-models) **\*NEW!\***
* [RWKV model](https://github.com/oobabooga/text-generation-webui/wiki/RWKV-model)
-* [LoRa (loading and training)](https://github.com/oobabooga/text-generation-webui/wiki/Using-LoRAs)
+* [LoRA (loading and training)](https://github.com/oobabooga/text-generation-webui/wiki/Using-LoRAs)
* Softprompts
* [Extensions](https://github.com/oobabooga/text-generation-webui/wiki/Extensions)
* [Google Colab](https://github.com/oobabooga/text-generation-webui/wiki/Running-on-Colab)
@@ -62,7 +63,7 @@ Recommended if you have some experience with the command-line.
On Windows, I additionally recommend carrying out the installation on WSL instead of the base system: [WSL installation guide](https://github.com/oobabooga/text-generation-webui/wiki/WSL-installation-guide).
-0. Install Conda
+#### 0. Install Conda
https://docs.conda.io/en/latest/miniconda.html
@@ -75,14 +76,14 @@ bash Miniconda3.sh
Source: https://educe-ubc.github.io/conda.html
-1. Create a new conda environment
+#### 1. Create a new conda environment
```
conda create -n textgen python=3.10.9
conda activate textgen
```
-2. Install Pytorch
+#### 2. Install Pytorch
| System | GPU | Command |
|--------|---------|---------|
@@ -92,10 +93,12 @@ conda activate textgen
The up to date commands can be found here: https://pytorch.org/get-started/locally/.
-MacOS users, refer to the comments here: https://github.com/oobabooga/text-generation-webui/pull/393
+#### 2.1 Special instructions
+* MacOS users: https://github.com/oobabooga/text-generation-webui/pull/393
+* AMD users: https://rentry.org/eq3hg
-3. Install the web UI
+#### 3. Install the web UI
```
git clone https://github.com/oobabooga/text-generation-webui
@@ -116,6 +119,15 @@ As an alternative to the recommended WSL method, you can install the web UI nati
https://github.com/oobabooga/text-generation-webui/issues/174, https://github.com/oobabooga/text-generation-webui/issues/87
+### Updating the requirements
+
+From time to time, the `requirements.txt` changes. To update, use this command:
+
+```
+conda activate textgen
+cd text-generation-webui
+pip install -r requirements.txt --upgrade
+```
## Downloading models
Models should be placed inside the `models` folder.
@@ -175,7 +187,6 @@ Optionally, you can use the following command-line flags:
| `-h`, `--help` | show this help message and exit |
| `--notebook` | Launch the web UI in notebook mode, where the output is written to the same text box as the input. |
| `--chat` | Launch the web UI in chat mode.|
-| `--cai-chat` | Launch the web UI in chat mode with a style similar to Character.AI's. If the file `img_bot.png` or `img_bot.jpg` exists in the same folder as server.py, this image will be used as the bot's profile picture. Similarly, `img_me.png` or `img_me.jpg` will be used as your profile picture. |
| `--model MODEL` | Name of the model to load by default. |
| `--lora LORA` | Name of the LoRA to apply to the model by default. |
| `--model-dir MODEL_DIR` | Path to directory with all the models |
diff --git a/api-example-stream.py b/api-example-stream.py
index e87fb74..32eefc7 100644
--- a/api-example-stream.py
+++ b/api-example-stream.py
@@ -36,6 +36,7 @@ async def run(context):
'early_stopping': False,
'seed': -1,
}
+ payload = json.dumps([context, params])
session = random_hash()
async with websockets.connect(f"ws://{server}:7860/queue/join") as websocket:
@@ -54,22 +55,7 @@ async def run(context):
"session_hash": session,
"fn_index": 12,
"data": [
- context,
- params['max_new_tokens'],
- params['do_sample'],
- params['temperature'],
- params['top_p'],
- params['typical_p'],
- params['repetition_penalty'],
- params['encoder_repetition_penalty'],
- params['top_k'],
- params['min_length'],
- params['no_repeat_ngram_size'],
- params['num_beams'],
- params['penalty_alpha'],
- params['length_penalty'],
- params['early_stopping'],
- params['seed'],
+ payload
]
}))
case "process_starts":
diff --git a/api-example.py b/api-example.py
index 0349824..10be0a8 100644
--- a/api-example.py
+++ b/api-example.py
@@ -10,6 +10,8 @@ Optionally, you can also add the --share flag to generate a public gradio URL,
allowing you to use the API remotely.
'''
+import json
+
import requests
# Server address
@@ -38,24 +40,11 @@ params = {
# Input prompt
prompt = "What I would like to say is the following: "
+payload = json.dumps([prompt, params])
+
response = requests.post(f"http://{server}:7860/run/textgen", json={
"data": [
- prompt,
- params['max_new_tokens'],
- params['do_sample'],
- params['temperature'],
- params['top_p'],
- params['typical_p'],
- params['repetition_penalty'],
- params['encoder_repetition_penalty'],
- params['top_k'],
- params['min_length'],
- params['no_repeat_ngram_size'],
- params['num_beams'],
- params['penalty_alpha'],
- params['length_penalty'],
- params['early_stopping'],
- params['seed'],
+ payload
]
}).json()
diff --git a/characters/Example.yaml b/characters/Example.yaml
index 948dece..0160f45 100644
--- a/characters/Example.yaml
+++ b/characters/Example.yaml
@@ -1,32 +1,16 @@
-name: Chiharu Yamada
-context: 'Chiharu Yamada''s Persona: Chiharu Yamada is a young, computer engineer-nerd
- with a knack for problem solving and a passion for technology.'
-greeting: '*Chiharu strides into the room with a smile, her eyes lighting up
- when she sees you. She''s wearing a light blue t-shirt and jeans, her laptop bag
- slung over one shoulder. She takes a seat next to you, her enthusiasm palpable in
- the air*
-
- Hey! I''m so excited to finally meet you. I''ve heard so many great things about
- you and I''m eager to pick your brain about computers. I''m sure you have a wealth
- of knowledge that I can learn from. *She grins, eyes twinkling with excitement*
- Let''s get started!'
-example_dialogue: '{{user}}: So how did you get into computer engineering?
-
- {{char}}: I''ve always loved tinkering with technology since I was a kid.
-
- {{user}}: That''s really impressive!
-
+name: "Chiharu Yamada"
+context: "Chiharu Yamada's Persona: Chiharu Yamada is a young, computer engineer-nerd with a knack for problem solving and a passion for technology."
+greeting: |-
+ *Chiharu strides into the room with a smile, her eyes lighting up when she sees you. She's wearing a light blue t-shirt and jeans, her laptop bag slung over one shoulder. She takes a seat next to you, her enthusiasm palpable in the air*
+ Hey! I'm so excited to finally meet you. I've heard so many great things about you and I'm eager to pick your brain about computers. I'm sure you have a wealth of knowledge that I can learn from. *She grins, eyes twinkling with excitement* Let's get started!
+example_dialogue: |-
+ {{user}}: So how did you get into computer engineering?
+ {{char}}: I've always loved tinkering with technology since I was a kid.
+ {{user}}: That's really impressive!
{{char}}: *She chuckles bashfully* Thanks!
-
- {{user}}: So what do you do when you''re not working on computers?
-
- {{char}}: I love exploring, going out with friends, watching movies, and playing
- video games.
-
- {{user}}: What''s your favorite type of computer hardware to work with?
-
- {{char}}: Motherboards, they''re like puzzles and the backbone of any system.
-
+ {{user}}: So what do you do when you're not working on computers?
+ {{char}}: I love exploring, going out with friends, watching movies, and playing video games.
+ {{user}}: What's your favorite type of computer hardware to work with?
+ {{char}}: Motherboards, they're like puzzles and the backbone of any system.
{{user}}: That sounds great!
-
- {{char}}: Yeah, it''s really fun. I''m lucky to be able to do this as a job.'
+ {{char}}: Yeah, it's really fun. I'm lucky to be able to do this as a job.
diff --git a/characters/instruction-following/Alpaca.yaml b/characters/instruction-following/Alpaca.yaml
new file mode 100644
index 0000000..3037324
--- /dev/null
+++ b/characters/instruction-following/Alpaca.yaml
@@ -0,0 +1,3 @@
+name: "### Response:"
+your_name: "### Instruction:"
+context: "Below is an instruction that describes a task. Write a response that appropriately completes the request."
diff --git a/characters/instruction-following/Open Assistant.yaml b/characters/instruction-following/Open Assistant.yaml
new file mode 100644
index 0000000..5b3320f
--- /dev/null
+++ b/characters/instruction-following/Open Assistant.yaml
@@ -0,0 +1,3 @@
+name: "<|assistant|>"
+your_name: "<|prompter|>"
+end_of_turn: "<|endoftext|>"
diff --git a/characters/instruction-following/Vicuna.yaml b/characters/instruction-following/Vicuna.yaml
new file mode 100644
index 0000000..026901d
--- /dev/null
+++ b/characters/instruction-following/Vicuna.yaml
@@ -0,0 +1,3 @@
+name: "### Assistant:"
+your_name: "### Human:"
+context: "Below is an instruction that describes a task. Write a response that appropriately completes the request."
diff --git a/css/html_cai_style.css b/css/html_cai_style.css
index 3190b3d..57c3b5c 100644
--- a/css/html_cai_style.css
+++ b/css/html_cai_style.css
@@ -64,6 +64,15 @@
line-height: 1.428571429 !important;
}
+.message-body li {
+ margin-top: 0.5em !important;
+ margin-bottom: 0.5em !important;
+}
+
+.message-body li > p {
+ display: inline !important;
+}
+
.dark .message-body p em {
color: rgb(138, 138, 138) !important;
}
diff --git a/css/html_instruct_style.css b/css/html_instruct_style.css
new file mode 100644
index 0000000..533c547
--- /dev/null
+++ b/css/html_instruct_style.css
@@ -0,0 +1,65 @@
+.chat {
+ margin-left: auto;
+ margin-right: auto;
+ max-width: 800px;
+ height: 66.67vh;
+ overflow-y: auto;
+ padding-right: 20px;
+ display: flex;
+ flex-direction: column-reverse;
+}
+
+.message {
+ display: grid;
+ grid-template-columns: 60px 1fr;
+ padding-bottom: 25px;
+ font-size: 15px;
+ font-family: Helvetica, Arial, sans-serif;
+ line-height: 1.428571429;
+}
+
+.username {
+ display: none;
+}
+
+.message-body {}
+
+.message-body p {
+ margin-bottom: 0 !important;
+ font-size: 15px !important;
+ line-height: 1.428571429 !important;
+}
+
+.message-body li {
+ margin-top: 0.5em !important;
+ margin-bottom: 0.5em !important;
+}
+
+.message-body li > p {
+ display: inline !important;
+}
+
+.dark .message-body p em {
+ color: rgb(138, 138, 138) !important;
+}
+
+.message-body p em {
+ color: rgb(110, 110, 110) !important;
+}
+
+.gradio-container .chat .assistant-message {
+ padding: 15px;
+ border-radius: 20px;
+ background-color: #0000000f;
+ margin-bottom: 17.5px;
+}
+
+.gradio-container .chat .user-message {
+ padding: 15px;
+ border-radius: 20px;
+ margin-bottom: 17.5px !important;
+}
+
+.dark .chat .assistant-message {
+ background-color: #ffffff21;
+}
\ No newline at end of file
diff --git a/css/main.css b/css/main.css
index 6aa3bc1..2d8f01e 100644
--- a/css/main.css
+++ b/css/main.css
@@ -41,7 +41,7 @@ ol li p, ul li p {
display: inline-block;
}
-#main, #parameters, #chat-settings, #interface-mode, #lora, #training-tab {
+#main, #parameters, #chat-settings, #interface-mode, #lora, #training-tab, #model-tab {
border: 0;
}
@@ -63,3 +63,7 @@ span.math.inline {
font-size: 27px;
vertical-align: baseline !important;
}
+
+div.svelte-15lo0d8 > *, div.svelte-15lo0d8 > .form > * {
+ flex-wrap: nowrap;
+}
diff --git a/extensions/api/script.py b/extensions/api/script.py
index 20562cc..6726d61 100644
--- a/extensions/api/script.py
+++ b/extensions/api/script.py
@@ -40,24 +40,27 @@ class Handler(BaseHTTPRequestHandler):
prompt_lines.pop(0)
prompt = '\n'.join(prompt_lines)
+ generate_params = {
+ 'max_new_tokens': int(body.get('max_length', 200)),
+ 'do_sample': bool(body.get('do_sample', True)),
+ 'temperature': float(body.get('temperature', 0.5)),
+ 'top_p': float(body.get('top_p', 1)),
+ 'typical_p': float(body.get('typical', 1)),
+ 'repetition_penalty': float(body.get('rep_pen', 1.1)),
+ 'encoder_repetition_penalty': 1,
+ 'top_k': int(body.get('top_k', 0)),
+ 'min_length': int(body.get('min_length', 0)),
+ 'no_repeat_ngram_size': int(body.get('no_repeat_ngram_size',0)),
+ 'num_beams': int(body.get('num_beams',1)),
+ 'penalty_alpha': float(body.get('penalty_alpha', 0)),
+ 'length_penalty': float(body.get('length_penalty', 1)),
+ 'early_stopping': bool(body.get('early_stopping', False)),
+ 'seed': int(body.get('seed', -1)),
+ }
generator = generate_reply(
- question = prompt,
- max_new_tokens = int(body.get('max_length', 200)),
- do_sample=bool(body.get('do_sample', True)),
- temperature=float(body.get('temperature', 0.5)),
- top_p=float(body.get('top_p', 1)),
- typical_p=float(body.get('typical', 1)),
- repetition_penalty=float(body.get('rep_pen', 1.1)),
- encoder_repetition_penalty=1,
- top_k=int(body.get('top_k', 0)),
- min_length=int(body.get('min_length', 0)),
- no_repeat_ngram_size=int(body.get('no_repeat_ngram_size',0)),
- num_beams=int(body.get('num_beams',1)),
- penalty_alpha=float(body.get('penalty_alpha', 0)),
- length_penalty=float(body.get('length_penalty', 1)),
- early_stopping=bool(body.get('early_stopping', False)),
- seed=int(body.get('seed', -1)),
+ prompt,
+ generate_params,
stopping_strings=body.get('stopping_strings', []),
)
diff --git a/extensions/gallery/script.py b/extensions/gallery/script.py
index 034506d..5c47f0f 100644
--- a/extensions/gallery/script.py
+++ b/extensions/gallery/script.py
@@ -2,9 +2,8 @@ from pathlib import Path
import gradio as gr
-from modules.chat import load_character
from modules.html_generator import get_image_cache
-from modules.shared import gradio, settings
+from modules.shared import gradio
def generate_css():
@@ -64,22 +63,13 @@ def generate_html():
for file in sorted(Path("characters").glob("*")):
if file.suffix in [".json", ".yml", ".yaml"]:
character = file.stem
- container_html = f'
'
+ container_html = '
'
image_html = "
"
- for i in [
- f"characters/{character}.png",
- f"characters/{character}.jpg",
- f"characters/{character}.jpeg",
- ]:
-
- path = Path(i)
+ for path in [Path(f"characters/{character}.{extension}") for extension in ['png', 'jpg', 'jpeg']]:
if path.exists():
- try:
- image_html = f'
})
'
- break
- except:
- continue
+ image_html = f'
})
'
+ break
container_html += f'{image_html}
{character}'
container_html += "
"
diff --git a/extensions/sd_api_pictures/script.py b/extensions/sd_api_pictures/script.py
index cc85f3b..df07ef2 100644
--- a/extensions/sd_api_pictures/script.py
+++ b/extensions/sd_api_pictures/script.py
@@ -176,4 +176,4 @@ def ui():
force_btn.click(force_pic)
generate_now_btn.click(force_pic)
- generate_now_btn.click(eval('chat.cai_chatbot_wrapper'), shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream)
\ No newline at end of file
+ generate_now_btn.click(chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream)
\ No newline at end of file
diff --git a/extensions/send_pictures/script.py b/extensions/send_pictures/script.py
index 556a88e..d2401df 100644
--- a/extensions/send_pictures/script.py
+++ b/extensions/send_pictures/script.py
@@ -2,12 +2,11 @@ import base64
from io import BytesIO
import gradio as gr
-import modules.chat as chat
-import modules.shared as shared
import torch
-from PIL import Image
from transformers import BlipForConditionalGeneration, BlipProcessor
+from modules import chat, shared
+
# If 'state' is True, will hijack the next chat generation with
# custom input text given by 'value' in the format [text, visible_text]
input_hijack = {
@@ -36,13 +35,11 @@ def generate_chat_picture(picture, name1, name2):
def ui():
picture_select = gr.Image(label='Send a picture', type='pil')
- function_call = 'chat.cai_chatbot_wrapper' if shared.args.cai_chat else 'chat.chatbot_wrapper'
-
# Prepare the hijack with custom inputs
picture_select.upload(lambda picture, name1, name2: input_hijack.update({"state": True, "value": generate_chat_picture(picture, name1, name2)}), [picture_select, shared.gradio['name1'], shared.gradio['name2']], None)
# Call the generation function
- picture_select.upload(eval(function_call), shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream)
+ picture_select.upload(chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream)
# Clear the picture from the upload field
picture_select.upload(lambda : None, [], [picture_select], show_progress=False)
diff --git a/modules/GPTQ_loader.py b/modules/GPTQ_loader.py
index e7877de..572947a 100644
--- a/modules/GPTQ_loader.py
+++ b/modules/GPTQ_loader.py
@@ -1,3 +1,4 @@
+import inspect
import re
import sys
from pathlib import Path
@@ -19,9 +20,9 @@ def _load_quant(model, checkpoint, wbits, groupsize=-1, faster_kernel=False, exc
config = AutoConfig.from_pretrained(model)
def noop(*args, **kwargs):
pass
- torch.nn.init.kaiming_uniform_ = noop
- torch.nn.init.uniform_ = noop
- torch.nn.init.normal_ = noop
+ torch.nn.init.kaiming_uniform_ = noop
+ torch.nn.init.uniform_ = noop
+ torch.nn.init.normal_ = noop
torch.set_default_dtype(torch.half)
transformers.modeling_utils._init_weights = False
@@ -33,16 +34,31 @@ def _load_quant(model, checkpoint, wbits, groupsize=-1, faster_kernel=False, exc
for name in exclude_layers:
if name in layers:
del layers[name]
- make_quant(model, layers, wbits, groupsize, faster=faster_kernel, kernel_switch_threshold=kernel_switch_threshold)
+
+ gptq_args = inspect.getfullargspec(make_quant).args
+
+ make_quant_kwargs = {
+ 'module': model,
+ 'names': layers,
+ 'bits': wbits,
+ }
+ if 'groupsize' in gptq_args:
+ make_quant_kwargs['groupsize'] = groupsize
+ if 'faster' in gptq_args:
+ make_quant_kwargs['faster'] = faster_kernel
+ if 'kernel_switch_threshold' in gptq_args:
+ make_quant_kwargs['kernel_switch_threshold'] = kernel_switch_threshold
+
+ make_quant(**make_quant_kwargs)
del layers
-
+
print('Loading model ...')
if checkpoint.endswith('.safetensors'):
from safetensors.torch import load_file as safe_load
- model.load_state_dict(safe_load(checkpoint))
+ model.load_state_dict(safe_load(checkpoint), strict = False)
else:
- model.load_state_dict(torch.load(checkpoint))
+ model.load_state_dict(torch.load(checkpoint), strict = False)
model.seqlen = 2048
print('Done.')
@@ -52,7 +68,7 @@ def load_quantized(model_name):
if not shared.args.model_type:
# Try to determine model type from model name
name = model_name.lower()
- if any((k in name for k in ['llama', 'alpaca'])):
+ if any((k in name for k in ['llama', 'alpaca', 'vicuna'])):
model_type = 'llama'
elif any((k in name for k in ['opt-', 'galactica'])):
model_type = 'opt'
@@ -65,16 +81,18 @@ def load_quantized(model_name):
else:
model_type = shared.args.model_type.lower()
- if model_type == 'llama' and shared.args.pre_layer:
+ if shared.args.pre_layer and model_type == 'llama':
load_quant = llama_inference_offload.load_quant
elif model_type in ('llama', 'opt', 'gptj'):
+ if shared.args.pre_layer:
+ print("Warning: ignoring --pre_layer because it only works for llama model type.")
load_quant = _load_quant
else:
print("Unknown pre-quantized model type specified. Only 'llama', 'opt' and 'gptj' are supported")
exit()
# Now we are going to try to locate the quantized model file.
- path_to_model = Path(f'models/{model_name}')
+ path_to_model = Path(f'{shared.args.model_dir}/{model_name}')
found_pts = list(path_to_model.glob("*.pt"))
found_safetensors = list(path_to_model.glob("*.safetensors"))
pt_path = None
@@ -95,8 +113,8 @@ def load_quantized(model_name):
else:
pt_model = f'{model_name}-{shared.args.wbits}bit'
- # Try to find the .safetensors or .pt both in models/ and in the subfolder
- for path in [Path(p+ext) for ext in ['.safetensors', '.pt'] for p in [f"models/{pt_model}", f"{path_to_model}/{pt_model}"]]:
+ # Try to find the .safetensors or .pt both in the model dir and in the subfolder
+ for path in [Path(p+ext) for ext in ['.safetensors', '.pt'] for p in [f"{shared.args.model_dir}/{pt_model}", f"{path_to_model}/{pt_model}"]]:
if path.exists():
print(f"Found {path}")
pt_path = path
@@ -107,7 +125,7 @@ def load_quantized(model_name):
exit()
# qwopqwop200's offload
- if shared.args.pre_layer:
+ if model_type == 'llama' and shared.args.pre_layer:
model = load_quant(str(path_to_model), str(pt_path), shared.args.wbits, shared.args.groupsize, shared.args.pre_layer)
else:
threshold = False if model_type == 'gptj' else 128
diff --git a/modules/api.py b/modules/api.py
new file mode 100644
index 0000000..26249fd
--- /dev/null
+++ b/modules/api.py
@@ -0,0 +1,38 @@
+import json
+
+import gradio as gr
+
+from modules import shared
+from modules.text_generation import generate_reply
+
+
+def generate_reply_wrapper(string):
+ generate_params = {
+ 'do_sample': True,
+ 'temperature': 1,
+ 'top_p': 1,
+ 'typical_p': 1,
+ 'repetition_penalty': 1,
+ 'encoder_repetition_penalty': 1,
+ 'top_k': 50,
+ 'num_beams': 1,
+ 'penalty_alpha': 0,
+ 'min_length': 0,
+ 'length_penalty': 1,
+ 'no_repeat_ngram_size': 0,
+ 'early_stopping': False,
+ }
+ params = json.loads(string)
+ for k in params[1]:
+ generate_params[k] = params[1][k]
+ for i in generate_reply(params[0], generate_params):
+ yield i
+
+def create_apis():
+ t1 = gr.Textbox(visible=False)
+ t2 = gr.Textbox(visible=False)
+ dummy = gr.Button(visible=False)
+
+ input_params = [t1]
+ output_params = [t2] + [shared.gradio[k] for k in ['markdown', 'html']]
+ dummy.click(generate_reply_wrapper, input_params, output_params, api_name='textgen')
diff --git a/modules/chat.py b/modules/chat.py
index cd8639c..3693264 100644
--- a/modules/chat.py
+++ b/modules/chat.py
@@ -12,45 +12,55 @@ from PIL import Image
import modules.extensions as extensions_module
import modules.shared as shared
from modules.extensions import apply_extensions
-from modules.html_generator import fix_newlines, generate_chat_html
+from modules.html_generator import (fix_newlines, chat_html_wrapper,
+ make_thumbnail)
from modules.text_generation import (encode, generate_reply,
get_max_prompt_length)
-def generate_chat_output(history, name1, name2, character):
- if shared.args.cai_chat:
- return generate_chat_html(history, name1, name2, character)
- else:
- return history
+def generate_chat_prompt(user_input, max_new_tokens, name1, name2, context, chat_prompt_size, **kwargs):
+ is_instruct = kwargs['is_instruct'] if 'is_instruct' in kwargs else False
+ end_of_turn = kwargs['end_of_turn'] if 'end_of_turn' in kwargs else ''
+ impersonate = kwargs['impersonate'] if 'impersonate' in kwargs else False
+ also_return_rows = kwargs['also_return_rows'] if 'also_return_rows' in kwargs else False
-def generate_chat_prompt(user_input, max_new_tokens, name1, name2, context, chat_prompt_size, impersonate=False, also_return_rows=False):
- user_input = fix_newlines(user_input)
rows = [f"{context.strip()}\n"]
+ # Finding the maximum prompt size
if shared.soft_prompt:
chat_prompt_size -= shared.soft_prompt_tensor.shape[1]
max_length = min(get_max_prompt_length(max_new_tokens), chat_prompt_size)
+ if is_instruct:
+ prefix1 = f"{name1}\n"
+ prefix2 = f"{name2}\n"
+ else:
+ prefix1 = f"{name1}: "
+ prefix2 = f"{name2}: "
+
i = len(shared.history['internal'])-1
while i >= 0 and len(encode(''.join(rows), max_new_tokens)[0]) < max_length:
- rows.insert(1, f"{name2}: {shared.history['internal'][i][1].strip()}\n")
- prev_user_input = shared.history['internal'][i][0]
- if prev_user_input not in ['', '<|BEGIN-VISIBLE-CHAT|>']:
- rows.insert(1, f"{name1}: {prev_user_input.strip()}\n")
+ rows.insert(1, f"{prefix2}{shared.history['internal'][i][1].strip()}{end_of_turn}\n")
+ string = shared.history['internal'][i][0]
+ if string not in ['', '<|BEGIN-VISIBLE-CHAT|>']:
+ rows.insert(1, f"{prefix1}{string.strip()}{end_of_turn}\n")
i -= 1
- if not impersonate:
- if len(user_input) > 0:
- rows.append(f"{name1}: {user_input}\n")
- rows.append(apply_extensions(f"{name2}:", "bot_prefix"))
- limit = 3
- else:
- rows.append(f"{name1}:")
+ if impersonate:
+ rows.append(f"{prefix1.strip() if not is_instruct else prefix1}")
limit = 2
+ else:
+ # Adding the user message
+ user_input = fix_newlines(user_input)
+ if len(user_input) > 0:
+ rows.append(f"{prefix1}{user_input}{end_of_turn}\n")
+
+ # Adding the Character prefix
+ rows.append(apply_extensions(f"{prefix2.strip() if not is_instruct else prefix2}", "bot_prefix"))
+ limit = 3
while len(rows) > limit and len(encode(''.join(rows), max_new_tokens)[0]) >= max_length:
rows.pop(1)
-
prompt = ''.join(rows)
if also_return_rows:
@@ -81,13 +91,20 @@ def extract_message_from_reply(reply, name1, name2, stop_at_newline):
if reply[-j:] == string[:j]:
reply = reply[:-j]
break
+ else:
+ continue
+ break
reply = fix_newlines(reply)
return reply, next_character_found
-def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, name1, name2, context, stop_at_newline, chat_prompt_size, chat_generation_attempts=1, regenerate=False):
- just_started = True
- eos_token = '\n' if stop_at_newline else None
+def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn, regenerate=False):
+ if mode == 'instruct':
+ stopping_strings = [f"\n{name1}", f"\n{name2}"]
+ else:
+ stopping_strings = [f"\n{name1}:", f"\n{name2}:"]
+
+ eos_token = '\n' if generate_state['stop_at_newline'] else None
name1_original = name1
if 'pygmalion' in shared.model_name.lower():
name1 = "You"
@@ -104,14 +121,13 @@ def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical
if visible_text is None:
visible_text = text
- if shared.args.chat:
- visible_text = visible_text.replace('\n', '
')
text = apply_extensions(text, "input")
+ kwargs = {'end_of_turn': end_of_turn, 'is_instruct': mode == 'instruct'}
if custom_generate_chat_prompt is None:
- prompt = generate_chat_prompt(text, max_new_tokens, name1, name2, context, chat_prompt_size)
+ prompt = generate_chat_prompt(text, generate_state['max_new_tokens'], name1, name2, context, generate_state['chat_prompt_size'], **kwargs)
else:
- prompt = custom_generate_chat_prompt(text, max_new_tokens, name1, name2, context, chat_prompt_size)
+ prompt = custom_generate_chat_prompt(text, generate_state['max_new_tokens'], name1, name2, context, generate_state['chat_prompt_size'], **kwargs)
# Yield *Is typing...*
if not regenerate:
@@ -119,17 +135,16 @@ def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical
# Generate
cumulative_reply = ''
- for i in range(chat_generation_attempts):
+ just_started = True
+ for i in range(generate_state['chat_generation_attempts']):
reply = None
- for reply in generate_reply(f"{prompt}{' ' if len(cumulative_reply) > 0 else ''}{cumulative_reply}", max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, eos_token=eos_token, stopping_strings=[f"\n{name1}:", f"\n{name2}:"]):
+ for reply in generate_reply(f"{prompt}{' ' if len(cumulative_reply) > 0 else ''}{cumulative_reply}", generate_state, eos_token=eos_token, stopping_strings=stopping_strings):
reply = cumulative_reply + reply
# Extracting the reply
- reply, next_character_found = extract_message_from_reply(reply, name1, name2, stop_at_newline)
+ reply, next_character_found = extract_message_from_reply(reply, name1, name2, generate_state['stop_at_newline'])
visible_reply = re.sub("(
||{{user}})", name1_original, reply)
visible_reply = apply_extensions(visible_reply, "output")
- if shared.args.chat:
- visible_reply = visible_reply.replace('\n', '
')
# We need this global variable to handle the Stop event,
# otherwise gradio gets confused
@@ -152,23 +167,27 @@ def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical
yield shared.history['visible']
-def impersonate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, name1, name2, context, stop_at_newline, chat_prompt_size, chat_generation_attempts=1):
- eos_token = '\n' if stop_at_newline else None
-
+def impersonate_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn):
+ if mode == 'instruct':
+ stopping_strings = [f"\n{name1}", f"\n{name2}"]
+ else:
+ stopping_strings = [f"\n{name1}:", f"\n{name2}:"]
+
+ eos_token = '\n' if generate_state['stop_at_newline'] else None
if 'pygmalion' in shared.model_name.lower():
name1 = "You"
- prompt = generate_chat_prompt(text, max_new_tokens, name1, name2, context, chat_prompt_size, impersonate=True)
+ prompt = generate_chat_prompt(text, generate_state['max_new_tokens'], name1, name2, context, generate_state['chat_prompt_size'], impersonate=True, end_of_turn=end_of_turn)
# Yield *Is typing...*
yield shared.processing_message
cumulative_reply = ''
- for i in range(chat_generation_attempts):
+ for i in range(generate_state['chat_generation_attempts']):
reply = None
- for reply in generate_reply(f"{prompt}{' ' if len(cumulative_reply) > 0 else ''}{cumulative_reply}", max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, eos_token=eos_token, stopping_strings=[f"\n{name1}:", f"\n{name2}:"]):
+ for reply in generate_reply(f"{prompt}{' ' if len(cumulative_reply) > 0 else ''}{cumulative_reply}", generate_state, eos_token=eos_token, stopping_strings=stopping_strings):
reply = cumulative_reply + reply
- reply, next_character_found = extract_message_from_reply(reply, name1, name2, stop_at_newline)
+ reply, next_character_found = extract_message_from_reply(reply, name1, name2, generate_state['stop_at_newline'])
yield reply
if next_character_found:
break
@@ -178,36 +197,30 @@ def impersonate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typ
yield reply
-def cai_chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, name1, name2, context, stop_at_newline, chat_prompt_size, chat_generation_attempts=1):
- for history in chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, name1, name2, context, stop_at_newline, chat_prompt_size, chat_generation_attempts):
- yield generate_chat_html(history, name1, name2, shared.character)
+def cai_chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn):
+ for history in chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn):
+ yield chat_html_wrapper(history, name1, name2, mode)
-def regenerate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, name1, name2, context, stop_at_newline, chat_prompt_size, chat_generation_attempts=1):
+def regenerate_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn):
if (shared.character != 'None' and len(shared.history['visible']) == 1) or len(shared.history['internal']) == 0:
- yield generate_chat_output(shared.history['visible'], name1, name2, shared.character)
+ yield chat_html_wrapper(shared.history['visible'], name1, name2, mode)
else:
last_visible = shared.history['visible'].pop()
last_internal = shared.history['internal'].pop()
# Yield '*Is typing...*'
- yield generate_chat_output(shared.history['visible']+[[last_visible[0], shared.processing_message]], name1, name2, shared.character)
- for history in chatbot_wrapper(last_internal[0], max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, name1, name2, context, stop_at_newline, chat_prompt_size, chat_generation_attempts, regenerate=True):
- if shared.args.cai_chat:
- shared.history['visible'][-1] = [last_visible[0], history[-1][1]]
- else:
- shared.history['visible'][-1] = (last_visible[0], history[-1][1])
- yield generate_chat_output(shared.history['visible'], name1, name2, shared.character)
+ yield chat_html_wrapper(shared.history['visible']+[[last_visible[0], shared.processing_message]], name1, name2, mode)
+ for history in chatbot_wrapper(last_internal[0], generate_state, name1, name2, context, mode, end_of_turn, regenerate=True):
+ shared.history['visible'][-1] = [last_visible[0], history[-1][1]]
+ yield chat_html_wrapper(shared.history['visible'], name1, name2, mode)
-def remove_last_message(name1, name2):
+def remove_last_message(name1, name2, mode):
if len(shared.history['visible']) > 0 and shared.history['internal'][-1][0] != '<|BEGIN-VISIBLE-CHAT|>':
last = shared.history['visible'].pop()
shared.history['internal'].pop()
else:
last = ['', '']
- if shared.args.cai_chat:
- return generate_chat_html(shared.history['visible'], name1, name2, shared.character), last[0]
- else:
- return shared.history['visible'], last[0]
+ return chat_html_wrapper(shared.history['visible'], name1, name2, mode), last[0]
def send_last_reply_to_input():
if len(shared.history['internal']) > 0:
@@ -215,20 +228,17 @@ def send_last_reply_to_input():
else:
return ''
-def replace_last_reply(text, name1, name2):
+def replace_last_reply(text, name1, name2, mode):
if len(shared.history['visible']) > 0:
- if shared.args.cai_chat:
- shared.history['visible'][-1][1] = text
- else:
- shared.history['visible'][-1] = (shared.history['visible'][-1][0], text)
+ shared.history['visible'][-1][1] = text
shared.history['internal'][-1][1] = apply_extensions(text, "input")
- return generate_chat_output(shared.history['visible'], name1, name2, shared.character)
+ return chat_html_wrapper(shared.history['visible'], name1, name2, mode)
def clear_html():
- return generate_chat_html([], "", "", shared.character)
+ return chat_html_wrapper([], "", "")
-def clear_chat_log(name1, name2, greeting):
+def clear_chat_log(name1, name2, greeting, mode):
shared.history['visible'] = []
shared.history['internal'] = []
@@ -236,12 +246,12 @@ def clear_chat_log(name1, name2, greeting):
shared.history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', greeting]]
shared.history['visible'] += [['', apply_extensions(greeting, "output")]]
- return generate_chat_output(shared.history['visible'], name1, name2, shared.character)
+ return chat_html_wrapper(shared.history['visible'], name1, name2, mode)
-def redraw_html(name1, name2):
- return generate_chat_html(shared.history['visible'], name1, name2, shared.character)
+def redraw_html(name1, name2, mode):
+ return chat_html_wrapper(shared.history['visible'], name1, name2, mode)
-def tokenize_dialogue(dialogue, name1, name2):
+def tokenize_dialogue(dialogue, name1, name2, mode):
history = []
dialogue = re.sub('', '', dialogue)
@@ -326,15 +336,35 @@ def build_pygmalion_style_context(data):
context = f"{context.strip()}\n\n"
return context
-def load_character(character, name1, name2):
+def generate_pfp_cache(character):
+ cache_folder = Path("cache")
+ if not cache_folder.exists():
+ cache_folder.mkdir()
+
+ for path in [Path(f"characters/{character}.{extension}") for extension in ['png', 'jpg', 'jpeg']]:
+ if path.exists():
+ img = make_thumbnail(Image.open(path))
+ img.save(Path('cache/pfp_character.png'), format='PNG')
+ return img
+ return None
+
+def load_character(character, name1, name2, mode):
shared.character = character
shared.history['internal'] = []
shared.history['visible'] = []
- greeting = ""
+ context = greeting = end_of_turn = ""
+ greeting_field = 'greeting'
+ picture = None
+
+ # Deleting the profile picture cache, if any
+ if Path("cache/pfp_character.png").exists():
+ Path("cache/pfp_character.png").unlink()
if character != 'None':
+ folder = 'characters' if not mode == 'instruct' else 'characters/instruction-following'
+ picture = generate_pfp_cache(character)
for extension in ["yml", "yaml", "json"]:
- filepath = Path(f'characters/{character}.{extension}')
+ filepath = Path(f'{folder}/{character}.{extension}')
if filepath.exists():
break
file_contents = open(filepath, 'r', encoding='utf-8').read()
@@ -350,19 +380,21 @@ def load_character(character, name1, name2):
if 'context' in data:
context = f"{data['context'].strip()}\n\n"
- greeting_field = 'greeting'
- else:
+ elif "char_persona" in data:
context = build_pygmalion_style_context(data)
greeting_field = 'char_greeting'
- if 'example_dialogue' in data and data['example_dialogue'] != '':
+ if 'example_dialogue' in data:
context += f"{data['example_dialogue'].strip()}\n"
- if greeting_field in data and len(data[greeting_field].strip()) > 0:
+ if greeting_field in data:
greeting = data[greeting_field]
+ if 'end_of_turn' in data:
+ end_of_turn = data['end_of_turn']
else:
context = shared.settings['context']
name2 = shared.settings['name2']
greeting = shared.settings['greeting']
+ end_of_turn = shared.settings['end_of_turn']
if Path(f'logs/{shared.character}_persistent.json').exists():
load_history(open(Path(f'logs/{shared.character}_persistent.json'), 'rb').read(), name1, name2)
@@ -370,13 +402,10 @@ def load_character(character, name1, name2):
shared.history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', greeting]]
shared.history['visible'] += [['', apply_extensions(greeting, "output")]]
- if shared.args.cai_chat:
- return name1, name2, greeting, context, generate_chat_html(shared.history['visible'], name1, name2, shared.character)
- else:
- return name1, name2, greeting, context, shared.history['visible']
+ return name1, name2, picture, greeting, context, end_of_turn, chat_html_wrapper(shared.history['visible'], name1, name2, mode, reset_cache=True)
def load_default_history(name1, name2):
- load_character("None", name1, name2)
+ load_character("None", name1, name2, "chat")
def upload_character(json_file, img, tavern=False):
json_file = json_file if type(json_file) == str else json_file.decode('utf-8')
@@ -404,7 +433,17 @@ def upload_tavern_character(img, name1, name2):
_json = {"char_name": _json['name'], "char_persona": _json['description'], "char_greeting": _json["first_mes"], "example_dialogue": _json['mes_example'], "world_scenario": _json['scenario']}
return upload_character(json.dumps(_json), img, tavern=True)
-def upload_your_profile_picture(img):
- img = Image.open(io.BytesIO(img))
- img.save(Path('img_me.png'))
- print('Profile picture saved to "img_me.png"')
+def upload_your_profile_picture(img, name1, name2, mode):
+ cache_folder = Path("cache")
+ if not cache_folder.exists():
+ cache_folder.mkdir()
+
+ if img == None:
+ if Path("cache/pfp_me.png").exists():
+ Path("cache/pfp_me.png").unlink()
+ else:
+ img = make_thumbnail(img)
+ img.save(Path('cache/pfp_me.png'))
+ print('Profile picture saved to "cache/pfp_me.png"')
+
+ return chat_html_wrapper(shared.history['visible'], name1, name2, mode, reset_cache=True)
diff --git a/modules/html_generator.py b/modules/html_generator.py
index 48d2e02..448c20c 100644
--- a/modules/html_generator.py
+++ b/modules/html_generator.py
@@ -6,10 +6,11 @@ This is a library for formatting text outputs as nice HTML.
import os
import re
+import time
from pathlib import Path
import markdown
-from PIL import Image
+from PIL import Image, ImageOps
# This is to store the paths to the thumbnails of the profile pictures
image_cache = {}
@@ -20,6 +21,8 @@ with open(Path(__file__).resolve().parent / '../css/html_4chan_style.css', 'r')
_4chan_css = css_f.read()
with open(Path(__file__).resolve().parent / '../css/html_cai_style.css', 'r') as f:
cai_css = f.read()
+with open(Path(__file__).resolve().parent / '../css/html_instruct_style.css', 'r') as f:
+ instruct_css = f.read()
def fix_newlines(string):
string = string.replace('\n', '\n\n')
@@ -95,6 +98,13 @@ def generate_4chan_html(f):
return output
+def make_thumbnail(image):
+ image = image.resize((350, round(image.size[1]/image.size[0]*350)), Image.Resampling.LANCZOS)
+ if image.size[1] > 470:
+ image = ImageOps.fit(image, (350, 470), Image.ANTIALIAS)
+
+ return image
+
def get_image_cache(path):
cache_folder = Path("cache")
if not cache_folder.exists():
@@ -102,26 +112,52 @@ def get_image_cache(path):
mtime = os.stat(path).st_mtime
if (path in image_cache and mtime != image_cache[path][0]) or (path not in image_cache):
- img = Image.open(path)
- img.thumbnail((200, 200))
+ img = make_thumbnail(Image.open(path))
output_file = Path(f'cache/{path.name}_cache.png')
img.convert('RGB').save(output_file, format='PNG')
image_cache[path] = [mtime, output_file.as_posix()]
return image_cache[path][1]
-def load_html_image(paths):
- for str_path in paths:
- path = Path(str_path)
- if path.exists():
- return f'
'
- return ''
+def generate_instruct_html(history):
+ output = f''
+ for i,_row in enumerate(history[::-1]):
+ row = [convert_to_markdown(entry) for entry in _row]
-def generate_chat_html(history, name1, name2, character):
+ output += f"""
+
+ """
+
+ if len(row[0]) == 0: # don't display empty user messages
+ continue
+
+ output += f"""
+
+ """
+
+ output += "
"
+
+ return output
+
+def generate_cai_chat_html(history, name1, name2, reset_cache=False):
output = f''
- img_bot = load_html_image([f"characters/{character}.{ext}" for ext in ['png', 'jpg', 'jpeg']] + ["img_bot.png","img_bot.jpg","img_bot.jpeg"])
- img_me = load_html_image(["img_me.png", "img_me.jpg", "img_me.jpeg"])
+ # The time.time() is to prevent the brower from caching the image
+ suffix = f"?{time.time()}" if reset_cache else f"?{name2}"
+ img_bot = f'

' if Path("cache/pfp_character.png").exists() else ''
+ img_me = f'

' if Path("cache/pfp_me.png").exists() else ''
for i,_row in enumerate(history[::-1]):
row = [convert_to_markdown(entry) for entry in _row]
@@ -163,3 +199,16 @@ def generate_chat_html(history, name1, name2, character):
output += "
"
return output
+
+def generate_chat_html(history, name1, name2):
+ return generate_cai_chat_html(history, name1, name2)
+
+def chat_html_wrapper(history, name1, name2, mode, reset_cache=False):
+ if mode == "cai-chat":
+ return generate_cai_chat_html(history, name1, name2, reset_cache)
+ elif mode == "chat":
+ return generate_chat_html(history, name1, name2)
+ elif mode == "instruct":
+ return generate_instruct_html(history)
+ else:
+ return ''
diff --git a/modules/llamacpp_model_alternative.py b/modules/llamacpp_model_alternative.py
new file mode 100644
index 0000000..4057611
--- /dev/null
+++ b/modules/llamacpp_model_alternative.py
@@ -0,0 +1,65 @@
+'''
+Based on
+https://github.com/abetlen/llama-cpp-python
+
+Documentation:
+https://abetlen.github.io/llama-cpp-python/
+'''
+
+import multiprocessing
+
+from llama_cpp import Llama
+
+from modules import shared
+from modules.callbacks import Iteratorize
+
+
+class LlamaCppModel:
+ def __init__(self):
+ self.initialized = False
+
+ @classmethod
+ def from_pretrained(self, path):
+ result = self()
+
+ params = {
+ 'model_path': str(path),
+ 'n_ctx': 2048,
+ 'seed': 0,
+ 'n_threads': shared.args.threads or None
+ }
+ self.model = Llama(**params)
+
+ # This is ugly, but the model and the tokenizer are the same object in this library.
+ return result, result
+
+ def encode(self, string):
+ if type(string) is str:
+ string = string.encode()
+ return self.model.tokenize(string)
+
+ def generate(self, context="", token_count=20, temperature=1, top_p=1, top_k=50, repetition_penalty=1, callback=None):
+ if type(context) is str:
+ context = context.encode()
+ tokens = self.model.tokenize(context)
+
+ output = b""
+ count = 0
+ for token in self.model.generate(tokens, top_k=top_k, top_p=top_p, temp=temperature, repeat_penalty=repetition_penalty):
+ text = self.model.detokenize([token])
+ output += text
+ if callback:
+ callback(text.decode())
+
+ count += 1
+ if count >= token_count or (token == self.model.token_eos()):
+ break
+
+ return output.decode()
+
+ def generate_with_streaming(self, **kwargs):
+ with Iteratorize(self.generate, kwargs, callback=None) as generator:
+ reply = ''
+ for token in generator:
+ reply += token
+ yield reply
diff --git a/modules/models.py b/modules/models.py
index edcb350..1bf6fc3 100644
--- a/modules/models.py
+++ b/modules/models.py
@@ -10,7 +10,7 @@ import torch
import transformers
from accelerate import infer_auto_device_map, init_empty_weights
from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer,
- BitsAndBytesConfig)
+ BitsAndBytesConfig, LlamaTokenizer)
import modules.shared as shared
@@ -42,7 +42,7 @@ def load_model(model_name):
t0 = time.time()
shared.is_RWKV = 'rwkv-' in model_name.lower()
- shared.is_llamacpp = len(list(Path(f'models/{model_name}').glob('ggml*.bin'))) > 0
+ shared.is_llamacpp = len(list(Path(f'{shared.args.model_dir}/{model_name}').glob('ggml*.bin'))) > 0
# Default settings
if not any([shared.args.cpu, shared.args.load_in_8bit, shared.args.wbits, shared.args.auto_devices, shared.args.disk, shared.args.gpu_memory is not None, shared.args.cpu_memory is not None, shared.args.deepspeed, shared.args.flexgen, shared.is_RWKV, shared.is_llamacpp]):
@@ -103,9 +103,9 @@ def load_model(model_name):
# llamacpp model
elif shared.is_llamacpp:
- from modules.llamacpp_model import LlamaCppModel
+ from modules.llamacpp_model_alternative import LlamaCppModel
- model_file = list(Path(f'models/{model_name}').glob('ggml*.bin'))[0]
+ model_file = list(Path(f'{shared.args.model_dir}/{model_name}').glob('ggml*.bin'))[0]
print(f"llama.cpp weights detected: {model_file}\n")
model, tokenizer = LlamaCppModel.from_pretrained(model_file)
@@ -172,6 +172,8 @@ def load_model(model_name):
# Loading the tokenizer
if any((k in shared.model_name.lower() for k in ['gpt4chan', 'gpt-4chan'])) and Path(f"{shared.args.model_dir}/gpt-j-6B/").exists():
tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/gpt-j-6B/"))
+ elif type(model) is transformers.LlamaForCausalLM:
+ tokenizer = LlamaTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/{shared.model_name}/"), clean_up_tokenization_spaces=True)
else:
tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/{shared.model_name}/"))
tokenizer.truncation_side = 'left'
diff --git a/modules/shared.py b/modules/shared.py
index 038e392..902d760 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -33,6 +33,7 @@ settings = {
'name2': 'Assistant',
'context': 'This is a conversation with your Assistant. The Assistant is very helpful and is eager to chat with you and answer your questions.',
'greeting': 'Hello there!',
+ 'end_of_turn': '',
'stop_at_newline': False,
'chat_prompt_size': 2048,
'chat_prompt_size_min': 0,
@@ -44,6 +45,7 @@ settings = {
'chat_default_extensions': ["gallery"],
'presets': {
'default': 'NovelAI-Sphinx Moth',
+ '.*(alpaca|llama)': "LLaMA-Precise",
'.*pygmalion': 'NovelAI-Storywriter',
'.*RWKV': 'Naive',
},
@@ -73,8 +75,8 @@ parser = argparse.ArgumentParser(formatter_class=lambda prog: argparse.HelpForma
# Basic settings
parser.add_argument('--notebook', action='store_true', help='Launch the web UI in notebook mode, where the output is written to the same text box as the input.')
-parser.add_argument('--chat', action='store_true', help='Launch the web UI in chat mode.')
-parser.add_argument('--cai-chat', action='store_true', help='Launch the web UI in chat mode with a style similar to Character.AI\'s. If the file img_bot.png or img_bot.jpg exists in the same folder as server.py, this image will be used as the bot\'s profile picture. Similarly, img_me.png or img_me.jpg will be used as your profile picture.')
+parser.add_argument('--chat', action='store_true', help='Launch the web UI in chat mode with a style similar to the Character.AI website.')
+parser.add_argument('--cai-chat', action='store_true', help='DEPRECATED: use --chat instead.')
parser.add_argument('--model', type=str, help='Name of the model to load by default.')
parser.add_argument('--lora', type=str, help='Name of the LoRA to apply to the model by default.')
parser.add_argument("--model-dir", type=str, default='models/', help="Path to directory with all the models")
@@ -131,12 +133,17 @@ parser.add_argument("--gradio-auth-path", type=str, help='Set the gradio authent
args = parser.parse_args()
-# Provisional, this will be deleted later
+# Deprecation warnings for parameters that have been renamed
deprecated_dict = {'gptq_bits': ['wbits', 0], 'gptq_model_type': ['model_type', None], 'gptq_pre_layer': ['prelayer', 0]}
for k in deprecated_dict:
if eval(f"args.{k}") != deprecated_dict[k][1]:
print(f"Warning: --{k} is deprecated and will be removed. Use --{deprecated_dict[k][0]} instead.")
exec(f"args.{deprecated_dict[k][0]} = args.{k}")
+# Deprecation warnings for parameters that have been removed
+if args.cai_chat:
+ print("Warning: --cai-chat is deprecated. Use --chat instead.")
+ args.chat = True
+
def is_chat():
- return any((args.chat, args.cai_chat))
+ return args.chat
diff --git a/modules/text_generation.py b/modules/text_generation.py
index 406c454..b8885ab 100644
--- a/modules/text_generation.py
+++ b/modules/text_generation.py
@@ -28,6 +28,10 @@ def encode(prompt, tokens_to_generate=0, add_special_tokens=True):
return input_ids
else:
input_ids = shared.tokenizer.encode(str(prompt), return_tensors='pt', truncation=True, max_length=get_max_prompt_length(tokens_to_generate), add_special_tokens=add_special_tokens)
+
+ if type(shared.tokenizer) is transformers.LlamaTokenizer and input_ids[0][0] == 29871:
+ input_ids = input_ids[:,1:]
+
if shared.args.cpu:
return input_ids
elif shared.args.flexgen:
@@ -102,10 +106,11 @@ def set_manual_seed(seed):
def stop_everything_event():
shared.stop_everything = True
-def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, eos_token=None, stopping_strings=[]):
+def generate_reply(question, generate_state, eos_token=None, stopping_strings=[]):
clear_torch_cache()
- set_manual_seed(seed)
+ set_manual_seed(generate_state['seed'])
shared.stop_everything = False
+ generate_params = {}
t0 = time.time()
original_question = question
@@ -117,9 +122,12 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
# These models are not part of Hugging Face, so we handle them
# separately and terminate the function call earlier
if any((shared.is_RWKV, shared.is_llamacpp)):
+ for k in ['temperature', 'top_p', 'top_k', 'repetition_penalty']:
+ generate_params[k] = generate_state[k]
+ generate_params["token_count"] = generate_state["max_new_tokens"]
try:
if shared.args.no_stream:
- reply = shared.model.generate(context=question, token_count=max_new_tokens, temperature=temperature, top_p=top_p, top_k=top_k, repetition_penalty=repetition_penalty)
+ reply = shared.model.generate(context=question, **generate_params)
output = original_question+reply
if not shared.is_chat():
reply = original_question + apply_extensions(reply, "output")
@@ -130,7 +138,7 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
# RWKV has proper streaming, which is very nice.
# No need to generate 8 tokens at a time.
- for reply in shared.model.generate_with_streaming(context=question, token_count=max_new_tokens, temperature=temperature, top_p=top_p, top_k=top_k, repetition_penalty=repetition_penalty):
+ for reply in shared.model.generate_with_streaming(context=question, **generate_params):
output = original_question+reply
if not shared.is_chat():
reply = original_question + apply_extensions(reply, "output")
@@ -145,7 +153,7 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
print(f"Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens})")
return
- input_ids = encode(question, max_new_tokens)
+ input_ids = encode(question, generate_state['max_new_tokens'])
original_input_ids = input_ids
output = input_ids[0]
@@ -158,33 +166,21 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
t = [encode(string, 0, add_special_tokens=False) for string in stopping_strings]
stopping_criteria_list.append(_SentinelTokenStoppingCriteria(sentinel_token_ids=t, starting_idx=len(input_ids[0])))
- generate_params = {}
+ generate_params["max_new_tokens"] = generate_state['max_new_tokens']
if not shared.args.flexgen:
- generate_params.update({
- "max_new_tokens": max_new_tokens,
- "eos_token_id": eos_token_ids,
- "stopping_criteria": stopping_criteria_list,
- "do_sample": do_sample,
- "temperature": temperature,
- "top_p": top_p,
- "typical_p": typical_p,
- "repetition_penalty": repetition_penalty,
- "encoder_repetition_penalty": encoder_repetition_penalty,
- "top_k": top_k,
- "min_length": min_length if shared.args.no_stream else 0,
- "no_repeat_ngram_size": no_repeat_ngram_size,
- "num_beams": num_beams,
- "penalty_alpha": penalty_alpha,
- "length_penalty": length_penalty,
- "early_stopping": early_stopping,
- })
+ for k in ["do_sample", "temperature", "top_p", "typical_p", "repetition_penalty", "encoder_repetition_penalty", "top_k", "min_length", "no_repeat_ngram_size", "num_beams", "penalty_alpha", "length_penalty", "early_stopping"]:
+ generate_params[k] = generate_state[k]
+ generate_params["eos_token_id"] = eos_token_ids
+ generate_params["stopping_criteria"] = stopping_criteria_list
+ if shared.args.no_stream:
+ generate_params["min_length"] = 0
else:
- generate_params.update({
- "max_new_tokens": max_new_tokens if shared.args.no_stream else 8,
- "do_sample": do_sample,
- "temperature": temperature,
- "stop": eos_token_ids[-1],
- })
+ for k in ["do_sample", "temperature"]:
+ generate_params[k] = generate_state[k]
+ generate_params["stop"] = generate_state["eos_token_ids"][-1]
+ if not shared.args.no_stream:
+ generate_params["max_new_tokens"] = 8
+
if shared.args.no_cache:
generate_params.update({"use_cache": False})
if shared.args.deepspeed:
@@ -244,7 +240,7 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
# Stream the output naively for FlexGen since it doesn't support 'stopping_criteria'
else:
- for i in range(max_new_tokens//8+1):
+ for i in range(generate_state['max_new_tokens']//8+1):
clear_torch_cache()
with torch.no_grad():
output = shared.model.generate(**generate_params)[0]
diff --git a/modules/training.py b/modules/training.py
index 5ba8d35..220428b 100644
--- a/modules/training.py
+++ b/modules/training.py
@@ -20,7 +20,7 @@ MAX_STEPS = 0
CURRENT_GRADIENT_ACCUM = 1
def get_dataset(path: str, ext: str):
- return ['None'] + sorted(set((k.stem for k in Path(path).glob(f'*.{ext}'))), key=str.lower)
+ return ['None'] + sorted(set([k.stem for k in Path(path).glob(f'*.{ext}') if k.stem != 'put-trainer-datasets-here']), key=str.lower)
def create_train_interface():
with gr.Tab('Train LoRA', elem_id='lora-train-tab'):
@@ -45,22 +45,26 @@ def create_train_interface():
with gr.Row():
dataset = gr.Dropdown(choices=get_dataset('training/datasets', 'json'), value='None', label='Dataset', info='The dataset file to use for training.')
ui.create_refresh_button(dataset, lambda : None, lambda : {'choices': get_dataset('training/datasets', 'json')}, 'refresh-button')
- eval_dataset = gr.Dropdown(choices=get_dataset('training/datasets', 'json'), value='None', label='Evaluation Dataset', info='The dataset file used to evaluate the model after training.')
+ eval_dataset = gr.Dropdown(choices=get_dataset('training/datasets', 'json'), value='None', label='Evaluation Dataset', info='The (optional) dataset file used to evaluate the model after training.')
ui.create_refresh_button(eval_dataset, lambda : None, lambda : {'choices': get_dataset('training/datasets', 'json')}, 'refresh-button')
format = gr.Dropdown(choices=get_dataset('training/formats', 'json'), value='None', label='Data Format', info='The format file used to decide how to format the dataset input.')
ui.create_refresh_button(format, lambda : None, lambda : {'choices': get_dataset('training/formats', 'json')}, 'refresh-button')
+
with gr.Tab(label="Raw Text File"):
with gr.Row():
raw_text_file = gr.Dropdown(choices=get_dataset('training/datasets', 'txt'), value='None', label='Text File', info='The raw text file to use for training.')
ui.create_refresh_button(raw_text_file, lambda : None, lambda : {'choices': get_dataset('training/datasets', 'txt')}, 'refresh-button')
- overlap_len = gr.Slider(label='Overlap Length', minimum=0,maximum=512, value=128, step=16, info='Overlap length - ie how many tokens from the prior chunk of text to include into the next chunk. (The chunks themselves will be of a size determined by Cutoff Length above). Setting overlap to exactly half the cutoff length may be ideal.')
+ with gr.Row():
+ overlap_len = gr.Slider(label='Overlap Length', minimum=0, maximum=512, value=128, step=16, info='Overlap length - ie how many tokens from the prior chunk of text to include into the next chunk. (The chunks themselves will be of a size determined by Cutoff Length below). Setting overlap to exactly half the cutoff length may be ideal.')
+ newline_favor_len = gr.Slider(label='Prefer Newline Cut Length', minimum=0, maximum=512, value=128, step=16, info='Length (in characters, not tokens) of the maximum distance to shift an overlap cut by to ensure chunks cut at newlines. If too low, cuts may occur in the middle of lines.')
with gr.Row():
start_button = gr.Button("Start LoRA Training")
stop_button = gr.Button("Interrupt")
output = gr.Markdown(value="Ready")
- start_button.click(do_train, [lora_name, micro_batch_size, batch_size, epochs, learning_rate, lora_rank, lora_alpha, lora_dropout, cutoff_len, dataset, eval_dataset, format, raw_text_file, overlap_len], [output])
+ start_button.click(do_train, [lora_name, micro_batch_size, batch_size, epochs, learning_rate, lora_rank, lora_alpha, lora_dropout,
+ cutoff_len, dataset, eval_dataset, format, raw_text_file, overlap_len, newline_favor_len], [output])
stop_button.click(do_interrupt, [], [], cancels=[], queue=False)
def do_interrupt():
@@ -91,8 +95,8 @@ def clean_path(base_path: str, path: str):
return path
return f'{Path(base_path).absolute()}/{path}'
-def do_train(lora_name: str, micro_batch_size: int, batch_size: int, epochs: int, learning_rate: str, lora_rank: int,
- lora_alpha: int, lora_dropout: float, cutoff_len: int, dataset: str, eval_dataset: str, format: str, raw_text_file: str, overlap_len: int):
+def do_train(lora_name: str, micro_batch_size: int, batch_size: int, epochs: int, learning_rate: str, lora_rank: int, lora_alpha: int, lora_dropout: float,
+ cutoff_len: int, dataset: str, eval_dataset: str, format: str, raw_text_file: str, overlap_len: int, newline_favor_len: int):
global WANT_INTERRUPT, CURRENT_STEPS, MAX_STEPS, CURRENT_GRADIENT_ACCUM
WANT_INTERRUPT = False
CURRENT_STEPS = 0
@@ -103,6 +107,25 @@ def do_train(lora_name: str, micro_batch_size: int, batch_size: int, epochs: int
lora_name = f"{shared.args.lora_dir}/{clean_path(None, lora_name)}"
actual_lr = float(learning_rate)
+ model_type = type(shared.model).__name__
+ if model_type != "LlamaForCausalLM":
+ if model_type == "PeftModelForCausalLM":
+ yield "You are trying to train a LoRA while you already have another LoRA loaded. This will work, but may have unexpected effects. *(Will continue anyway in 5 seconds, press `Interrupt` to stop.)*"
+ print("Warning: Training LoRA over top of another LoRA. May have unexpected effects.")
+ else:
+ yield "LoRA training has only currently been validated for LLaMA models. Unexpected errors may follow. *(Will continue anyway in 5 seconds, press `Interrupt` to stop.)*"
+ print(f"Warning: LoRA training has only currently been validated for LLaMA models. (Found model type: {model_type})")
+ time.sleep(5)
+
+ if shared.args.wbits > 0 or shared.args.gptq_bits > 0:
+ yield "LoRA training does not yet support 4bit. Please use `--load-in-8bit` for now."
+ return
+
+ elif not shared.args.load_in_8bit:
+ yield "It is highly recommended you use `--load-in-8bit` for LoRA training. *(Will continue anyway in 2 seconds, press `Interrupt` to stop.)*"
+ print("Warning: It is highly recommended you use `--load-in-8bit` for LoRA training.")
+ time.sleep(2) # Give it a moment for the message to show in UI before continuing
+
if cutoff_len <= 0 or micro_batch_size <= 0 or batch_size <= 0 or actual_lr <= 0 or lora_rank <= 0 or lora_alpha <= 0:
yield "Cannot input zeroes."
return
@@ -126,15 +149,20 @@ def do_train(lora_name: str, micro_batch_size: int, batch_size: int, epochs: int
raw_text = file.read()
tokens = shared.tokenizer.encode(raw_text)
del raw_text # Note: could be a gig for a large dataset, so delete redundant data as we go to be safe on RAM
+
tokens = list(split_chunks(tokens, cutoff_len - overlap_len))
for i in range(1, len(tokens)):
tokens[i] = tokens[i - 1][-overlap_len:] + tokens[i]
text_chunks = [shared.tokenizer.decode(x) for x in tokens]
del tokens
- data = Dataset.from_list([tokenize(x) for x in text_chunks])
- train_data = data.shuffle()
- eval_data = None
+
+ if newline_favor_len > 0:
+ text_chunks = [cut_chunk_for_newline(x, newline_favor_len) for x in text_chunks]
+
+ train_data = Dataset.from_list([tokenize(x) for x in text_chunks])
del text_chunks
+ train_data = train_data.shuffle()
+ eval_data = None
else:
if dataset in ['None', '']:
@@ -232,33 +260,37 @@ def do_train(lora_name: str, micro_batch_size: int, batch_size: int, epochs: int
# TODO: save/load checkpoints to resume from?
print("Starting training...")
yield "Starting..."
+ if WANT_INTERRUPT:
+ yield "Interrupted before start."
+ return
- def threadedRun():
+ def threaded_run():
trainer.train()
- thread = threading.Thread(target=threadedRun)
+ thread = threading.Thread(target=threaded_run)
thread.start()
- lastStep = 0
- startTime = time.perf_counter()
+ last_step = 0
+ start_time = time.perf_counter()
while thread.is_alive():
time.sleep(0.5)
if WANT_INTERRUPT:
yield "Interrupting, please wait... *(Run will stop after the current training step completes.)*"
- elif CURRENT_STEPS != lastStep:
- lastStep = CURRENT_STEPS
- timeElapsed = time.perf_counter() - startTime
- if timeElapsed <= 0:
- timerInfo = ""
- totalTimeEstimate = 999
+
+ elif CURRENT_STEPS != last_step:
+ last_step = CURRENT_STEPS
+ time_elapsed = time.perf_counter() - start_time
+ if time_elapsed <= 0:
+ timer_info = ""
+ total_time_estimate = 999
else:
- its = CURRENT_STEPS / timeElapsed
+ its = CURRENT_STEPS / time_elapsed
if its > 1:
- timerInfo = f"`{its:.2f}` it/s"
+ timer_info = f"`{its:.2f}` it/s"
else:
- timerInfo = f"`{1.0/its:.2f}` s/it"
- totalTimeEstimate = (1.0/its) * (MAX_STEPS)
- yield f"Running... **{CURRENT_STEPS}** / **{MAX_STEPS}** ... {timerInfo}, `{timeElapsed:.0f}`/`{totalTimeEstimate:.0f}` seconds"
+ timer_info = f"`{1.0/its:.2f}` s/it"
+ total_time_estimate = (1.0/its) * (MAX_STEPS)
+ yield f"Running... **{CURRENT_STEPS}** / **{MAX_STEPS}** ... {timer_info}, {format_time(time_elapsed)} / {format_time(total_time_estimate)} ... {format_time(total_time_estimate - time_elapsed)} remaining"
print("Training complete, saving...")
lora_model.save_pretrained(lora_name)
@@ -273,3 +305,25 @@ def do_train(lora_name: str, micro_batch_size: int, batch_size: int, epochs: int
def split_chunks(arr, step):
for i in range(0, len(arr), step):
yield arr[i:i + step]
+
+def cut_chunk_for_newline(chunk: str, max_length: int):
+ if '\n' not in chunk:
+ return chunk
+ first_newline = chunk.index('\n')
+ if first_newline < max_length:
+ chunk = chunk[first_newline + 1:]
+ if '\n' not in chunk:
+ return chunk
+ last_newline = chunk.rindex('\n')
+ if len(chunk) - last_newline < max_length:
+ chunk = chunk[:last_newline]
+ return chunk
+
+def format_time(seconds: float):
+ if seconds < 120:
+ return f"`{seconds:.0f}` seconds"
+ minutes = seconds / 60
+ if minutes < 120:
+ return f"`{minutes:.0f}` minutes"
+ hours = minutes / 60
+ return f"`{hours:.0f}` hours"
diff --git a/presets/LLaMA-Precise.txt b/presets/LLaMA-Precise.txt
new file mode 100644
index 0000000..8098b39
--- /dev/null
+++ b/presets/LLaMA-Precise.txt
@@ -0,0 +1,6 @@
+do_sample=True
+top_p=0.1
+top_k=40
+temperature=0.7
+repetition_penalty=1.18
+typical_p=1.0
diff --git a/requirements.txt b/requirements.txt
index 6d802df..aa1a38d 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -3,12 +3,11 @@ bitsandbytes==0.37.2
datasets
flexgen==0.1.7
gradio==3.24.1
-llamacpp==0.1.11
markdown
numpy
peft==0.2.0
requests
-rwkv==0.7.2
+rwkv==0.7.3
safetensors==0.3.0
sentencepiece
pyyaml
diff --git a/server.py b/server.py
index 0a837c5..4ba5ba8 100644
--- a/server.py
+++ b/server.py
@@ -1,3 +1,7 @@
+import os
+
+os.environ['GRADIO_ANALYTICS_ENABLED'] = 'False'
+
import io
import json
import re
@@ -8,10 +12,11 @@ from datetime import datetime
from pathlib import Path
import gradio as gr
+from PIL import Image
import modules.extensions as extensions_module
-from modules import chat, shared, training, ui
-from modules.html_generator import generate_chat_html
+from modules import chat, shared, training, ui, api
+from modules.html_generator import chat_html_wrapper
from modules.LoRA import add_lora_to_model
from modules.models import load_model, load_soft_prompt
from modules.text_generation import (clear_torch_cache, generate_reply,
@@ -47,6 +52,13 @@ def get_available_prompts():
def get_available_characters():
paths = (x for x in Path('characters').iterdir() if x.suffix in ('.json', '.yaml', '.yml'))
+ return ['None'] + sorted(set((k.stem for k in paths if k.stem != "instruction-following")), key=str.lower)
+
+def get_available_instruction_templates():
+ path = "characters/instruction-following"
+ paths = []
+ if os.path.exists(path):
+ paths = (x for x in Path(path).iterdir() if x.suffix in ('.json', '.yaml', '.yml'))
return ['None'] + sorted(set((k.stem for k in paths)), key=str.lower)
def get_available_extensions():
@@ -76,7 +88,7 @@ def load_lora_wrapper(selected_lora):
add_lora_to_model(selected_lora)
return selected_lora
-def load_preset_values(preset_menu, return_dict=False):
+def load_preset_values(preset_menu, state, return_dict=False):
generate_params = {
'do_sample': True,
'temperature': 1,
@@ -98,13 +110,13 @@ def load_preset_values(preset_menu, return_dict=False):
i = i.rstrip(',').strip().split('=')
if len(i) == 2 and i[0].strip() != 'tokens':
generate_params[i[0].strip()] = eval(i[1].strip())
-
generate_params['temperature'] = min(1.99, generate_params['temperature'])
if return_dict:
return generate_params
else:
- return generate_params['do_sample'], generate_params['temperature'], generate_params['top_p'], generate_params['typical_p'], generate_params['repetition_penalty'], generate_params['encoder_repetition_penalty'], generate_params['top_k'], generate_params['min_length'], generate_params['no_repeat_ngram_size'], generate_params['num_beams'], generate_params['penalty_alpha'], generate_params['length_penalty'], generate_params['early_stopping']
+ state.update(generate_params)
+ return state, *[generate_params[k] for k in ['do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']]
def upload_soft_prompt(file):
with zipfile.ZipFile(io.BytesIO(file)) as zf:
@@ -118,19 +130,8 @@ def upload_soft_prompt(file):
return name
-def create_model_and_preset_menus():
- with gr.Row():
- with gr.Column():
- with gr.Row():
- shared.gradio['model_menu'] = gr.Dropdown(choices=available_models, value=shared.model_name, label='Model')
- ui.create_refresh_button(shared.gradio['model_menu'], lambda : None, lambda : {'choices': get_available_models()}, 'refresh-button')
- with gr.Column():
- with gr.Row():
- shared.gradio['preset_menu'] = gr.Dropdown(choices=available_presets, value=default_preset if not shared.args.flexgen else 'Naive', label='Generation parameters preset')
- ui.create_refresh_button(shared.gradio['preset_menu'], lambda : None, lambda : {'choices': get_available_presets()}, 'refresh-button')
-
def save_prompt(text):
- fname = f"{datetime.now().strftime('%Y-%m-%d-%H:%M:%S')}.txt"
+ fname = f"{datetime.now().strftime('%Y-%m-%d-%H%M%S')}.txt"
with open(Path(f'prompts/{fname}'), 'w', encoding='utf-8') as f:
f.write(text)
return f"Saved to prompts/{fname}"
@@ -144,7 +145,7 @@ def load_prompt(fname):
if text[-1] == '\n':
text = text[:-1]
return text
-
+
def create_prompt_menus():
with gr.Row():
with gr.Column():
@@ -160,12 +161,31 @@ def create_prompt_menus():
shared.gradio['prompt_menu'].change(load_prompt, [shared.gradio['prompt_menu']], [shared.gradio['textbox']], show_progress=False)
shared.gradio['save_prompt'].click(save_prompt, [shared.gradio['textbox']], [shared.gradio['status']], show_progress=False)
+def create_model_menus():
+ with gr.Row():
+ with gr.Column():
+ with gr.Row():
+ shared.gradio['model_menu'] = gr.Dropdown(choices=available_models, value=shared.model_name, label='Model')
+ ui.create_refresh_button(shared.gradio['model_menu'], lambda : None, lambda : {'choices': get_available_models()}, 'refresh-button')
+ with gr.Column():
+ with gr.Row():
+ shared.gradio['lora_menu'] = gr.Dropdown(choices=available_loras, value=shared.lora_name, label='LoRA')
+ ui.create_refresh_button(shared.gradio['lora_menu'], lambda : None, lambda : {'choices': get_available_loras()}, 'refresh-button')
+
+ shared.gradio['model_menu'].change(load_model_wrapper, shared.gradio['model_menu'], shared.gradio['model_menu'], show_progress=True)
+ shared.gradio['lora_menu'].change(load_lora_wrapper, shared.gradio['lora_menu'], shared.gradio['lora_menu'], show_progress=True)
+
def create_settings_menus(default_preset):
- generate_params = load_preset_values(default_preset if not shared.args.flexgen else 'Naive', return_dict=True)
+ generate_params = load_preset_values(default_preset if not shared.args.flexgen else 'Naive', {}, return_dict=True)
+ for k in ['max_new_tokens', 'seed', 'stop_at_newline', 'chat_prompt_size', 'chat_generation_attempts']:
+ generate_params[k] = shared.settings[k]
+ shared.gradio['generate_state'] = gr.State(generate_params)
with gr.Row():
with gr.Column():
- create_model_and_preset_menus()
+ with gr.Row():
+ shared.gradio['preset_menu'] = gr.Dropdown(choices=available_presets, value=default_preset if not shared.args.flexgen else 'Naive', label='Generation parameters preset')
+ ui.create_refresh_button(shared.gradio['preset_menu'], lambda : None, lambda : {'choices': get_available_presets()}, 'refresh-button')
with gr.Column():
shared.gradio['seed'] = gr.Number(value=shared.settings['seed'], label='Seed (-1 for random)')
@@ -199,9 +219,6 @@ def create_settings_menus(default_preset):
shared.gradio['length_penalty'] = gr.Slider(-5, 5, value=generate_params['length_penalty'], label='length_penalty')
shared.gradio['early_stopping'] = gr.Checkbox(value=generate_params['early_stopping'], label='early_stopping')
- with gr.Row():
- shared.gradio['lora_menu'] = gr.Dropdown(choices=available_loras, value=shared.lora_name, label='LoRA')
- ui.create_refresh_button(shared.gradio['lora_menu'], lambda : None, lambda : {'choices': get_available_loras()}, 'refresh-button')
with gr.Accordion('Soft prompt', open=False):
with gr.Row():
@@ -212,17 +229,14 @@ def create_settings_menus(default_preset):
with gr.Row():
shared.gradio['upload_softprompt'] = gr.File(type='binary', file_types=['.zip'])
- shared.gradio['model_menu'].change(load_model_wrapper, [shared.gradio['model_menu']], [shared.gradio['model_menu']], show_progress=True)
- shared.gradio['preset_menu'].change(load_preset_values, [shared.gradio['preset_menu']], [shared.gradio[k] for k in ['do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']])
- shared.gradio['lora_menu'].change(load_lora_wrapper, [shared.gradio['lora_menu']], [shared.gradio['lora_menu']], show_progress=True)
- shared.gradio['softprompts_menu'].change(load_soft_prompt, [shared.gradio['softprompts_menu']], [shared.gradio['softprompts_menu']], show_progress=True)
- shared.gradio['upload_softprompt'].upload(upload_soft_prompt, [shared.gradio['upload_softprompt']], [shared.gradio['softprompts_menu']])
+ shared.gradio['preset_menu'].change(load_preset_values, [shared.gradio[k] for k in ['preset_menu', 'generate_state']], [shared.gradio[k] for k in ['generate_state', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']])
+ shared.gradio['softprompts_menu'].change(load_soft_prompt, shared.gradio['softprompts_menu'], shared.gradio['softprompts_menu'], show_progress=True)
+ shared.gradio['upload_softprompt'].upload(upload_soft_prompt, shared.gradio['upload_softprompt'], shared.gradio['softprompts_menu'])
def set_interface_arguments(interface_mode, extensions, bool_active):
modes = ["default", "notebook", "chat", "cai_chat"]
cmd_list = vars(shared.args)
bool_list = [k for k in cmd_list if type(cmd_list[k]) is bool and k not in modes]
- #int_list = [k for k in cmd_list if type(k) is int]
shared.args.extensions = extensions
for k in modes[1:]:
@@ -295,10 +309,7 @@ def create_interface():
if shared.is_chat():
shared.gradio['Chat input'] = gr.State()
with gr.Tab("Text generation", elem_id="main"):
- if shared.args.cai_chat:
- shared.gradio['display'] = gr.HTML(value=generate_chat_html(shared.history['visible'], shared.settings['name1'], shared.settings['name2'], shared.character))
- else:
- shared.gradio['display'] = gr.Chatbot(value=shared.history['visible'], elem_id="gradio-chatbot")
+ shared.gradio['display'] = gr.HTML(value=chat_html_wrapper(shared.history['visible'], shared.settings['name1'], shared.settings['name2'], 'cai-chat'))
shared.gradio['textbox'] = gr.Textbox(label='Input')
with gr.Row():
shared.gradio['Generate'] = gr.Button('Generate')
@@ -315,11 +326,20 @@ def create_interface():
shared.gradio['Clear history-confirm'] = gr.Button('Confirm', variant="stop", visible=False)
shared.gradio['Clear history-cancel'] = gr.Button('Cancel', visible=False)
+ shared.gradio["Chat mode"] = gr.Radio(choices=["cai-chat", "chat", "instruct"], value="cai-chat", label="Mode")
+ shared.gradio["Instruction templates"] = gr.Dropdown(choices=get_available_instruction_templates(), label="Instruction template", value="None", visible=False)
+
with gr.Tab("Character", elem_id="chat-settings"):
- shared.gradio['name1'] = gr.Textbox(value=shared.settings['name1'], lines=1, label='Your name')
- shared.gradio['name2'] = gr.Textbox(value=shared.settings['name2'], lines=1, label='Character\'s name')
- shared.gradio['greeting'] = gr.Textbox(value=shared.settings['greeting'], lines=2, label='Greeting')
- shared.gradio['context'] = gr.Textbox(value=shared.settings['context'], lines=8, label='Context')
+ with gr.Row():
+ with gr.Column(scale=8):
+ shared.gradio['name1'] = gr.Textbox(value=shared.settings['name1'], lines=1, label='Your name')
+ shared.gradio['name2'] = gr.Textbox(value=shared.settings['name2'], lines=1, label='Character\'s name')
+ shared.gradio['greeting'] = gr.Textbox(value=shared.settings['greeting'], lines=4, label='Greeting')
+ shared.gradio['context'] = gr.Textbox(value=shared.settings['context'], lines=4, label='Context')
+ shared.gradio['end_of_turn'] = gr.Textbox(value=shared.settings["end_of_turn"], lines=1, label='End of turn string')
+ with gr.Column(scale=1):
+ shared.gradio['character_picture'] = gr.Image(label='Character picture', type="pil")
+ shared.gradio['your_picture'] = gr.Image(label='Your picture', type="pil", value=Image.open(Path("cache/pfp_me.png")) if Path("cache/pfp_me.png").exists() else None)
with gr.Row():
shared.gradio['character_menu'] = gr.Dropdown(choices=available_characters, value='None', label='Character', elem_id='character-menu')
ui.create_refresh_button(shared.gradio['character_menu'], lambda : None, lambda : {'choices': get_available_characters()}, 'refresh-button')
@@ -347,8 +367,6 @@ def create_interface():
gr.Markdown("# TavernAI PNG format")
shared.gradio['upload_img_tavern'] = gr.File(type='binary', file_types=['image'])
- with gr.Tab('Upload your profile picture'):
- shared.gradio['upload_img_me'] = gr.File(type='binary', file_types=['image'])
with gr.Tab("Parameters", elem_id="parameters"):
with gr.Box():
@@ -359,35 +377,35 @@ def create_interface():
shared.gradio['chat_prompt_size_slider'] = gr.Slider(minimum=shared.settings['chat_prompt_size_min'], maximum=shared.settings['chat_prompt_size_max'], step=1, label='Maximum prompt size in tokens', value=shared.settings['chat_prompt_size'])
with gr.Column():
shared.gradio['chat_generation_attempts'] = gr.Slider(minimum=shared.settings['chat_generation_attempts_min'], maximum=shared.settings['chat_generation_attempts_max'], value=shared.settings['chat_generation_attempts'], step=1, label='Generation attempts (for longer replies)')
- shared.gradio['check'] = gr.Checkbox(value=shared.settings['stop_at_newline'], label='Stop generating at new line character?')
+ shared.gradio['stop_at_newline'] = gr.Checkbox(value=shared.settings['stop_at_newline'], label='Stop generating at new line character?')
create_settings_menus(default_preset)
- function_call = 'chat.cai_chatbot_wrapper' if shared.args.cai_chat else 'chat.chatbot_wrapper'
- shared.input_params = [shared.gradio[k] for k in ['Chat input', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'seed', 'name1', 'name2', 'context', 'check', 'chat_prompt_size_slider', 'chat_generation_attempts']]
+ shared.input_params = [shared.gradio[k] for k in ['Chat input', 'generate_state', 'name1', 'name2', 'context', 'Chat mode', 'end_of_turn']]
def set_chat_input(textbox):
return textbox, ""
gen_events.append(shared.gradio['Generate'].click(set_chat_input, shared.gradio['textbox'], [shared.gradio['Chat input'], shared.gradio['textbox']], show_progress=False))
- gen_events.append(shared.gradio['Generate'].click(eval(function_call), shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream))
+ gen_events.append(shared.gradio['Generate'].click(chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream))
gen_events.append(shared.gradio['textbox'].submit(set_chat_input, shared.gradio['textbox'], [shared.gradio['Chat input'], shared.gradio['textbox']], show_progress=False))
- gen_events.append(shared.gradio['textbox'].submit(eval(function_call), shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream))
+ gen_events.append(shared.gradio['textbox'].submit(chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream))
gen_events.append(shared.gradio['Regenerate'].click(chat.regenerate_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream))
gen_events.append(shared.gradio['Impersonate'].click(chat.impersonate_wrapper, shared.input_params, shared.gradio['textbox'], show_progress=shared.args.no_stream))
shared.gradio['Stop'].click(stop_everything_event, [], [], queue=False, cancels=gen_events if shared.args.no_stream else None)
shared.gradio['Copy last reply'].click(chat.send_last_reply_to_input, [], shared.gradio['textbox'], show_progress=shared.args.no_stream)
- shared.gradio['Replace last reply'].click(chat.replace_last_reply, [shared.gradio['textbox'], shared.gradio['name1'], shared.gradio['name2']], shared.gradio['display'], show_progress=shared.args.no_stream)
+ shared.gradio['Replace last reply'].click(chat.replace_last_reply, [shared.gradio[k] for k in ['textbox', 'name1', 'name2', 'Chat mode']], shared.gradio['display'], show_progress=shared.args.no_stream)
# Clear history with confirmation
clear_arr = [shared.gradio[k] for k in ['Clear history-confirm', 'Clear history', 'Clear history-cancel']]
shared.gradio['Clear history'].click(lambda :[gr.update(visible=True), gr.update(visible=False), gr.update(visible=True)], None, clear_arr)
shared.gradio['Clear history-confirm'].click(lambda :[gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, clear_arr)
- shared.gradio['Clear history-confirm'].click(chat.clear_chat_log, [shared.gradio['name1'], shared.gradio['name2'], shared.gradio['greeting']], shared.gradio['display'])
+ shared.gradio['Clear history-confirm'].click(chat.clear_chat_log, [shared.gradio[k] for k in ['name1', 'name2', 'greeting', 'Chat mode']], shared.gradio['display'])
shared.gradio['Clear history-cancel'].click(lambda :[gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, clear_arr)
+ shared.gradio['Chat mode'].change(lambda x : gr.update(visible= x=='instruct'), shared.gradio['Chat mode'], shared.gradio['Instruction templates'])
- shared.gradio['Remove last'].click(chat.remove_last_message, [shared.gradio['name1'], shared.gradio['name2']], [shared.gradio['display'], shared.gradio['textbox']], show_progress=False)
+ shared.gradio['Remove last'].click(chat.remove_last_message, [shared.gradio[k] for k in ['name1', 'name2', 'Chat mode']], [shared.gradio['display'], shared.gradio['textbox']], show_progress=False)
shared.gradio['download_button'].click(chat.save_history, inputs=[], outputs=[shared.gradio['download']])
shared.gradio['Upload character'].click(chat.upload_character, [shared.gradio['upload_json'], shared.gradio['upload_img_bot']], [shared.gradio['character_menu']])
@@ -399,20 +417,21 @@ def create_interface():
shared.gradio['textbox'].submit(lambda x: '', shared.gradio['textbox'], shared.gradio['textbox'], show_progress=False)
shared.gradio['textbox'].submit(lambda : chat.save_history(timestamp=False), [], [], show_progress=False)
- shared.gradio['character_menu'].change(chat.load_character, [shared.gradio[k] for k in ['character_menu', 'name1', 'name2']], [shared.gradio[k] for k in ['name1', 'name2', 'greeting', 'context', 'display']])
- shared.gradio['upload_chat_history'].upload(chat.load_history, [shared.gradio['upload_chat_history'], shared.gradio['name1'], shared.gradio['name2']], [])
+ shared.gradio['character_menu'].change(chat.load_character, [shared.gradio[k] for k in ['character_menu', 'name1', 'name2', 'Chat mode']], [shared.gradio[k] for k in ['name1', 'name2', 'character_picture', 'greeting', 'context', 'end_of_turn', 'display']])
+ shared.gradio['Instruction templates'].change(lambda character, name1, name2, mode: chat.load_character(character, name1, name2, mode), [shared.gradio[k] for k in ['Instruction templates', 'name1', 'name2', 'Chat mode']], [shared.gradio[k] for k in ['name1', 'name2', 'character_picture', 'greeting', 'context', 'end_of_turn', 'display']])
+ shared.gradio['upload_chat_history'].upload(chat.load_history, [shared.gradio[k] for k in ['upload_chat_history', 'name1', 'name2']], [])
shared.gradio['upload_img_tavern'].upload(chat.upload_tavern_character, [shared.gradio['upload_img_tavern'], shared.gradio['name1'], shared.gradio['name2']], [shared.gradio['character_menu']])
- shared.gradio['upload_img_me'].upload(chat.upload_your_profile_picture, [shared.gradio['upload_img_me']], [])
+ shared.gradio['your_picture'].change(chat.upload_your_profile_picture, [shared.gradio[k] for k in ['your_picture', 'name1', 'name2', 'Chat mode']], shared.gradio['display'])
- reload_func = chat.redraw_html if shared.args.cai_chat else lambda : shared.history['visible']
- reload_inputs = [shared.gradio['name1'], shared.gradio['name2']] if shared.args.cai_chat else []
- shared.gradio['upload_chat_history'].upload(reload_func, reload_inputs, [shared.gradio['display']])
- shared.gradio['upload_img_me'].upload(reload_func, reload_inputs, [shared.gradio['display']])
- shared.gradio['Stop'].click(reload_func, reload_inputs, [shared.gradio['display']])
+ reload_inputs = [shared.gradio[k] for k in ['name1', 'name2', 'Chat mode']]
+ shared.gradio['upload_chat_history'].upload(chat.redraw_html, reload_inputs, [shared.gradio['display']])
+ shared.gradio['Stop'].click(chat.redraw_html, reload_inputs, [shared.gradio['display']])
+ shared.gradio['Instruction templates'].change(chat.redraw_html, reload_inputs, [shared.gradio['display']])
+ shared.gradio['Chat mode'].change(chat.redraw_html, reload_inputs, [shared.gradio['display']])
shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js+ui.chat_js}}}")
shared.gradio['interface'].load(lambda : chat.load_default_history(shared.settings['name1'], shared.settings['name2']), None, None)
- shared.gradio['interface'].load(reload_func, reload_inputs, [shared.gradio['display']], show_progress=True)
+ shared.gradio['interface'].load(chat.redraw_html, reload_inputs, [shared.gradio['display']], show_progress=True)
elif shared.args.notebook:
with gr.Tab("Text generation", elem_id="main"):
@@ -442,9 +461,9 @@ def create_interface():
with gr.Tab("Parameters", elem_id="parameters"):
create_settings_menus(default_preset)
- shared.input_params = [shared.gradio[k] for k in ['textbox', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'seed']]
+ shared.input_params = [shared.gradio[k] for k in ['textbox', 'generate_state']]
output_params = [shared.gradio[k] for k in ['textbox', 'markdown', 'html']]
- gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream, api_name='textgen'))
+ gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream))
gen_events.append(shared.gradio['textbox'].submit(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream))
shared.gradio['Stop'].click(stop_everything_event, [], [], queue=False, cancels=gen_events if shared.args.no_stream else None)
shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js}}}")
@@ -475,14 +494,17 @@ def create_interface():
with gr.Tab("Parameters", elem_id="parameters"):
create_settings_menus(default_preset)
- shared.input_params = [shared.gradio[k] for k in ['textbox', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'seed']]
+ shared.input_params = [shared.gradio[k] for k in ['textbox', 'generate_state']]
output_params = [shared.gradio[k] for k in ['output_textbox', 'markdown', 'html']]
- gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream, api_name='textgen'))
+ gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream))
gen_events.append(shared.gradio['textbox'].submit(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream))
gen_events.append(shared.gradio['Continue'].click(generate_reply, [shared.gradio['output_textbox']] + shared.input_params[1:], output_params, show_progress=shared.args.no_stream))
shared.gradio['Stop'].click(stop_everything_event, [], [], queue=False, cancels=gen_events if shared.args.no_stream else None)
shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js}}}")
+ with gr.Tab("Model", elem_id="model-tab"):
+ create_model_menus()
+
with gr.Tab("Training", elem_id="training-tab"):
training.create_train_interface()
@@ -496,7 +518,6 @@ def create_interface():
cmd_list = vars(shared.args)
bool_list = [k for k in cmd_list if type(cmd_list[k]) is bool and k not in modes]
bool_active = [k for k in bool_list if vars(shared.args)[k]]
- #int_list = [k for k in cmd_list if type(k) is int]
gr.Markdown("*Experimental*")
shared.gradio['interface_modes_menu'] = gr.Dropdown(choices=modes, value=current_mode, label="Mode")
@@ -510,6 +531,21 @@ def create_interface():
if shared.args.extensions is not None:
extensions_module.create_extensions_block()
+ def change_dict_value(d, key, value):
+ d[key] = value
+ return d
+
+ for k in ['do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'max_new_tokens', 'seed', 'stop_at_newline', 'chat_prompt_size_slider', 'chat_generation_attempts']:
+ if k not in shared.gradio:
+ continue
+ if type(shared.gradio[k]) in [gr.Checkbox, gr.Number]:
+ shared.gradio[k].change(lambda state, value, copy=k: change_dict_value(state, copy, value), inputs=[shared.gradio['generate_state'], shared.gradio[k]], outputs=shared.gradio['generate_state'])
+ else:
+ shared.gradio[k].release(lambda state, value, copy=k: change_dict_value(state, copy, value), inputs=[shared.gradio['generate_state'], shared.gradio[k]], outputs=shared.gradio['generate_state'])
+
+ if not shared.is_chat():
+ api.create_apis()
+
# Authentication
auth = None
if shared.args.gradio_auth_path is not None: