Compare commits
6 Commits
reformat2
...
prompt_tem
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
776b7914bf | ||
|
|
2944c6d204 | ||
|
|
cbaa231a0a | ||
|
|
065383ec67 | ||
|
|
214dd6307e | ||
|
|
a500061b08 |
@@ -1,10 +0,0 @@
|
||||
.env
|
||||
Dockerfile
|
||||
/characters
|
||||
/extensions
|
||||
/loras
|
||||
/models
|
||||
/presets
|
||||
/prompts
|
||||
/softprompts
|
||||
/training
|
||||
25
.env.example
25
.env.example
@@ -1,25 +0,0 @@
|
||||
# by default the Dockerfile specifies these versions: 3.5;5.0;6.0;6.1;7.0;7.5;8.0;8.6+PTX
|
||||
# however for me to work i had to specify the exact version for my card ( 2060 ) it was 7.5
|
||||
# https://developer.nvidia.com/cuda-gpus you can find the version for your card here
|
||||
TORCH_CUDA_ARCH_LIST=7.5
|
||||
|
||||
# these commands worked for me with roughly 4.5GB of vram
|
||||
CLI_ARGS=--model llama-7b-4bit --wbits 4 --listen --auto-devices
|
||||
|
||||
# the following examples have been tested with the files linked in docs/README_docker.md:
|
||||
# example running 13b with 4bit/128 groupsize : CLI_ARGS=--model llama-13b-4bit-128g --wbits 4 --listen --groupsize 128 --pre_layer 25
|
||||
# example with loading api extension and public share: CLI_ARGS=--model llama-7b-4bit --wbits 4 --listen --auto-devices --no-stream --extensions api --share
|
||||
# example running 7b with 8bit groupsize : CLI_ARGS=--model llama-7b --load-in-8bit --listen --auto-devices
|
||||
|
||||
# the port the webui binds to on the host
|
||||
HOST_PORT=7860
|
||||
# the port the webui binds to inside the container
|
||||
CONTAINER_PORT=7860
|
||||
|
||||
# the port the api binds to on the host
|
||||
HOST_API_PORT=5000
|
||||
# the port the api binds to inside the container
|
||||
CONTAINER_API_PORT=5000
|
||||
|
||||
# the version used to install text-generation-webui from
|
||||
WEBUI_VERSION=HEAD
|
||||
61
Dockerfile
61
Dockerfile
@@ -1,61 +0,0 @@
|
||||
FROM nvidia/cuda:11.8.0-devel-ubuntu22.04 as builder
|
||||
|
||||
RUN apt-get update && \
|
||||
apt-get install --no-install-recommends -y git vim build-essential python3-dev python3-venv && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
|
||||
RUN git clone https://github.com/oobabooga/GPTQ-for-LLaMa /build
|
||||
|
||||
WORKDIR /build
|
||||
|
||||
RUN python3 -m venv /build/venv
|
||||
RUN . /build/venv/bin/activate && \
|
||||
pip3 install --upgrade pip setuptools && \
|
||||
pip3 install torch torchvision torchaudio && \
|
||||
pip3 install -r requirements.txt
|
||||
|
||||
# https://developer.nvidia.com/cuda-gpus
|
||||
# for a rtx 2060: ARG TORCH_CUDA_ARCH_LIST="7.5"
|
||||
ARG TORCH_CUDA_ARCH_LIST="3.5;5.0;6.0;6.1;7.0;7.5;8.0;8.6+PTX"
|
||||
RUN . /build/venv/bin/activate && \
|
||||
python3 setup_cuda.py bdist_wheel -d .
|
||||
|
||||
FROM nvidia/cuda:11.8.0-runtime-ubuntu22.04
|
||||
|
||||
LABEL maintainer="Your Name <your.email@example.com>"
|
||||
LABEL description="Docker image for GPTQ-for-LLaMa and Text Generation WebUI"
|
||||
|
||||
RUN apt-get update && \
|
||||
apt-get install --no-install-recommends -y git python3 python3-pip && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
|
||||
RUN --mount=type=cache,target=/root/.cache/pip pip3 install virtualenv
|
||||
|
||||
COPY . /app/
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
ARG WEBUI_VERSION
|
||||
RUN test -n "${WEBUI_VERSION}" && git reset --hard ${WEBUI_VERSION} || echo "Using provided webui source"
|
||||
|
||||
RUN virtualenv /app/venv
|
||||
RUN . /app/venv/bin/activate && \
|
||||
pip3 install --upgrade pip setuptools && \
|
||||
pip3 install torch torchvision torchaudio && \
|
||||
pip3 install -r requirements.txt
|
||||
|
||||
COPY --from=builder /build /app/repositories/GPTQ-for-LLaMa
|
||||
RUN . /app/venv/bin/activate && \
|
||||
pip3 install /app/repositories/GPTQ-for-LLaMa/*.whl
|
||||
|
||||
ENV CLI_ARGS=""
|
||||
|
||||
RUN --mount=type=cache,target=/root/.cache/pip . /app/venv/bin/activate && cd extensions/api && pip3 install -r requirements.txt
|
||||
RUN --mount=type=cache,target=/root/.cache/pip . /app/venv/bin/activate && cd extensions/elevenlabs_tts && pip3 install -r requirements.txt
|
||||
RUN --mount=type=cache,target=/root/.cache/pip . /app/venv/bin/activate && cd extensions/google_translate && pip3 install -r requirements.txt
|
||||
RUN --mount=type=cache,target=/root/.cache/pip . /app/venv/bin/activate && cd extensions/silero_tts && pip3 install -r requirements.txt
|
||||
RUN --mount=type=cache,target=/root/.cache/pip . /app/venv/bin/activate && cd extensions/whisper_stt && pip3 install -r requirements.txt
|
||||
|
||||
RUN cp /app/venv/lib/python3.10/site-packages/bitsandbytes/libbitsandbytes_cuda118.so /app/venv/lib/python3.10/site-packages/bitsandbytes/libbitsandbytes_cpu.so
|
||||
|
||||
CMD . /app/venv/bin/activate && python3 server.py ${CLI_ARGS}
|
||||
38
README.md
38
README.md
@@ -15,7 +15,6 @@ Its goal is to become the [AUTOMATIC1111/stable-diffusion-webui](https://github.
|
||||
* Dropdown menu for switching between models
|
||||
* Notebook mode that resembles OpenAI's playground
|
||||
* Chat mode for conversation and role playing
|
||||
* Instruct mode compatible with Alpaca and Open Assistant formats **\*NEW!\***
|
||||
* Nice HTML output for GPT-4chan
|
||||
* Markdown output for [GALACTICA](https://github.com/paperswithcode/galai), including LaTeX rendering
|
||||
* [Custom chat characters](https://github.com/oobabooga/text-generation-webui/wiki/Custom-chat-characters)
|
||||
@@ -27,11 +26,11 @@ Its goal is to become the [AUTOMATIC1111/stable-diffusion-webui](https://github.
|
||||
* CPU mode
|
||||
* [FlexGen](https://github.com/oobabooga/text-generation-webui/wiki/FlexGen)
|
||||
* [DeepSpeed ZeRO-3](https://github.com/oobabooga/text-generation-webui/wiki/DeepSpeed)
|
||||
* API [with](https://github.com/oobabooga/text-generation-webui/blob/main/api-example-stream.py) streaming and [without](https://github.com/oobabooga/text-generation-webui/blob/main/api-example.py) streaming
|
||||
* API [with](https://github.com/oobabooga/text-generation-webui/blob/main/api-example-streaming.py) streaming and [without](https://github.com/oobabooga/text-generation-webui/blob/main/api-example.py) streaming
|
||||
* [LLaMA model, including 4-bit GPTQ](https://github.com/oobabooga/text-generation-webui/wiki/LLaMA-model)
|
||||
* [llama.cpp](https://github.com/oobabooga/text-generation-webui/wiki/llama.cpp-models) **\*NEW!\***
|
||||
* [RWKV model](https://github.com/oobabooga/text-generation-webui/wiki/RWKV-model)
|
||||
* [LoRA (loading and training)](https://github.com/oobabooga/text-generation-webui/wiki/Using-LoRAs)
|
||||
* [LoRa (loading and training)](https://github.com/oobabooga/text-generation-webui/wiki/Using-LoRAs)
|
||||
* Softprompts
|
||||
* [Extensions](https://github.com/oobabooga/text-generation-webui/wiki/Extensions)
|
||||
* [Google Colab](https://github.com/oobabooga/text-generation-webui/wiki/Running-on-Colab)
|
||||
@@ -63,7 +62,7 @@ Recommended if you have some experience with the command-line.
|
||||
|
||||
On Windows, I additionally recommend carrying out the installation on WSL instead of the base system: [WSL installation guide](https://github.com/oobabooga/text-generation-webui/wiki/WSL-installation-guide).
|
||||
|
||||
#### 0. Install Conda
|
||||
0. Install Conda
|
||||
|
||||
https://docs.conda.io/en/latest/miniconda.html
|
||||
|
||||
@@ -76,14 +75,14 @@ bash Miniconda3.sh
|
||||
|
||||
Source: https://educe-ubc.github.io/conda.html
|
||||
|
||||
#### 1. Create a new conda environment
|
||||
1. Create a new conda environment
|
||||
|
||||
```
|
||||
conda create -n textgen python=3.10.9
|
||||
conda activate textgen
|
||||
```
|
||||
|
||||
#### 2. Install Pytorch
|
||||
2. Install Pytorch
|
||||
|
||||
| System | GPU | Command |
|
||||
|--------|---------|---------|
|
||||
@@ -93,12 +92,10 @@ conda activate textgen
|
||||
|
||||
The up to date commands can be found here: https://pytorch.org/get-started/locally/.
|
||||
|
||||
#### 2.1 Special instructions
|
||||
MacOS users, refer to the comments here: https://github.com/oobabooga/text-generation-webui/pull/393
|
||||
|
||||
* MacOS users: https://github.com/oobabooga/text-generation-webui/pull/393
|
||||
* AMD users: https://rentry.org/eq3hg
|
||||
|
||||
#### 3. Install the web UI
|
||||
3. Install the web UI
|
||||
|
||||
```
|
||||
git clone https://github.com/oobabooga/text-generation-webui
|
||||
@@ -117,26 +114,8 @@ As an alternative to the recommended WSL method, you can install the web UI nati
|
||||
|
||||
### Alternative: Docker
|
||||
|
||||
```
|
||||
cp .env.example .env
|
||||
docker-compose up --build
|
||||
```
|
||||
https://github.com/oobabooga/text-generation-webui/issues/174, https://github.com/oobabooga/text-generation-webui/issues/87
|
||||
|
||||
Make sure to edit `.env.example` and set the appropriate CUDA version for your GPU.
|
||||
|
||||
You need to have docker compose v2.17 or higher installed in your system. For installation instructions, see [Docker compose installation](https://github.com/oobabooga/text-generation-webui/wiki/Docker-compose-installation).
|
||||
|
||||
Contributed by [@loeken](https://github.com/loeken) in [#633](https://github.com/oobabooga/text-generation-webui/pull/633)
|
||||
|
||||
### Updating the requirements
|
||||
|
||||
From time to time, the `requirements.txt` changes. To update, use this command:
|
||||
|
||||
```
|
||||
conda activate textgen
|
||||
cd text-generation-webui
|
||||
pip install -r requirements.txt --upgrade
|
||||
```
|
||||
## Downloading models
|
||||
|
||||
Models should be placed inside the `models` folder.
|
||||
@@ -196,6 +175,7 @@ Optionally, you can use the following command-line flags:
|
||||
| `-h`, `--help` | show this help message and exit |
|
||||
| `--notebook` | Launch the web UI in notebook mode, where the output is written to the same text box as the input. |
|
||||
| `--chat` | Launch the web UI in chat mode.|
|
||||
| `--cai-chat` | Launch the web UI in chat mode with a style similar to Character.AI's. If the file `img_bot.png` or `img_bot.jpg` exists in the same folder as server.py, this image will be used as the bot's profile picture. Similarly, `img_me.png` or `img_me.jpg` will be used as your profile picture. |
|
||||
| `--model MODEL` | Name of the model to load by default. |
|
||||
| `--lora LORA` | Name of the LoRA to apply to the model by default. |
|
||||
| `--model-dir MODEL_DIR` | Path to directory with all the models |
|
||||
|
||||
@@ -17,7 +17,6 @@ def random_hash():
|
||||
letters = string.ascii_lowercase + string.digits
|
||||
return ''.join(random.choice(letters) for i in range(9))
|
||||
|
||||
|
||||
async def run(context):
|
||||
server = "127.0.0.1"
|
||||
params = {
|
||||
@@ -37,12 +36,11 @@ async def run(context):
|
||||
'early_stopping': False,
|
||||
'seed': -1,
|
||||
}
|
||||
payload = json.dumps([context, params])
|
||||
session = random_hash()
|
||||
|
||||
async with websockets.connect(f"ws://{server}:7860/queue/join") as websocket:
|
||||
while content := json.loads(await websocket.recv()):
|
||||
# Python3.10 syntax, replace with if elif on older
|
||||
#Python3.10 syntax, replace with if elif on older
|
||||
match content["msg"]:
|
||||
case "send_hash":
|
||||
await websocket.send(json.dumps({
|
||||
@@ -56,7 +54,22 @@ async def run(context):
|
||||
"session_hash": session,
|
||||
"fn_index": 12,
|
||||
"data": [
|
||||
payload
|
||||
context,
|
||||
params['max_new_tokens'],
|
||||
params['do_sample'],
|
||||
params['temperature'],
|
||||
params['top_p'],
|
||||
params['typical_p'],
|
||||
params['repetition_penalty'],
|
||||
params['encoder_repetition_penalty'],
|
||||
params['top_k'],
|
||||
params['min_length'],
|
||||
params['no_repeat_ngram_size'],
|
||||
params['num_beams'],
|
||||
params['penalty_alpha'],
|
||||
params['length_penalty'],
|
||||
params['early_stopping'],
|
||||
params['seed'],
|
||||
]
|
||||
}))
|
||||
case "process_starts":
|
||||
@@ -70,7 +83,6 @@ async def run(context):
|
||||
|
||||
prompt = "What I would like to say is the following: "
|
||||
|
||||
|
||||
async def get_result():
|
||||
async for response in run(prompt):
|
||||
# Print intermediate steps
|
||||
|
||||
@@ -10,8 +10,6 @@ Optionally, you can also add the --share flag to generate a public gradio URL,
|
||||
allowing you to use the API remotely.
|
||||
|
||||
'''
|
||||
import json
|
||||
|
||||
import requests
|
||||
|
||||
# Server address
|
||||
@@ -40,11 +38,24 @@ params = {
|
||||
# Input prompt
|
||||
prompt = "What I would like to say is the following: "
|
||||
|
||||
payload = json.dumps([prompt, params])
|
||||
|
||||
response = requests.post(f"http://{server}:7860/run/textgen", json={
|
||||
"data": [
|
||||
payload
|
||||
prompt,
|
||||
params['max_new_tokens'],
|
||||
params['do_sample'],
|
||||
params['temperature'],
|
||||
params['top_p'],
|
||||
params['typical_p'],
|
||||
params['repetition_penalty'],
|
||||
params['encoder_repetition_penalty'],
|
||||
params['top_k'],
|
||||
params['min_length'],
|
||||
params['no_repeat_ngram_size'],
|
||||
params['num_beams'],
|
||||
params['penalty_alpha'],
|
||||
params['length_penalty'],
|
||||
params['early_stopping'],
|
||||
params['seed'],
|
||||
]
|
||||
}).json()
|
||||
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
name: "Chiharu Yamada"
|
||||
context: "Chiharu Yamada's Persona: Chiharu Yamada is a young, computer engineer-nerd with a knack for problem solving and a passion for technology."
|
||||
greeting: |-
|
||||
*Chiharu strides into the room with a smile, her eyes lighting up when she sees you. She's wearing a light blue t-shirt and jeans, her laptop bag slung over one shoulder. She takes a seat next to you, her enthusiasm palpable in the air*
|
||||
|
||||
@@ -1,3 +0,0 @@
|
||||
name: "### Response:"
|
||||
your_name: "### Instruction:"
|
||||
context: "Below is an instruction that describes a task. Write a response that appropriately completes the request."
|
||||
@@ -1,3 +0,0 @@
|
||||
name: "<|assistant|>"
|
||||
your_name: "<|prompter|>"
|
||||
end_of_turn: "<|endoftext|>"
|
||||
@@ -1,3 +0,0 @@
|
||||
name: "### Assistant:"
|
||||
your_name: "### Human:"
|
||||
context: "Below is an instruction that describes a task. Write a response that appropriately completes the request."
|
||||
@@ -13,11 +13,10 @@ import torch
|
||||
from tqdm import tqdm
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
parser = argparse.ArgumentParser(formatter_class=lambda prog: argparse.HelpFormatter(prog, max_help_position=54))
|
||||
parser = argparse.ArgumentParser(formatter_class=lambda prog: argparse.HelpFormatter(prog,max_help_position=54))
|
||||
parser.add_argument('MODEL', type=str, default=None, nargs='?', help="Path to the input model.")
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
def disable_torch_init():
|
||||
"""
|
||||
Disable the redundant torch default initialization to accelerate model creation.
|
||||
@@ -32,22 +31,20 @@ def disable_torch_init():
|
||||
torch_layer_norm_init_backup = torch.nn.LayerNorm.reset_parameters
|
||||
setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
|
||||
|
||||
|
||||
def restore_torch_init():
|
||||
"""Rollback the change made by disable_torch_init."""
|
||||
import torch
|
||||
setattr(torch.nn.Linear, "reset_parameters", torch_linear_init_backup)
|
||||
setattr(torch.nn.LayerNorm, "reset_parameters", torch_layer_norm_init_backup)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
path = Path(args.MODEL)
|
||||
model_name = path.name
|
||||
|
||||
print(f"Loading {model_name}...")
|
||||
# disable_torch_init()
|
||||
#disable_torch_init()
|
||||
model = AutoModelForCausalLM.from_pretrained(path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
|
||||
# restore_torch_init()
|
||||
#restore_torch_init()
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(path)
|
||||
|
||||
|
||||
@@ -17,7 +17,7 @@ from pathlib import Path
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
parser = argparse.ArgumentParser(formatter_class=lambda prog: argparse.HelpFormatter(prog, max_help_position=54))
|
||||
parser = argparse.ArgumentParser(formatter_class=lambda prog: argparse.HelpFormatter(prog,max_help_position=54))
|
||||
parser.add_argument('MODEL', type=str, default=None, nargs='?', help="Path to the input model.")
|
||||
parser.add_argument('--output', type=str, default=None, help='Path to the output folder (default: models/{model_name}_safetensors).')
|
||||
parser.add_argument("--max-shard-size", type=str, default="2GB", help="Maximum size of a shard in GB or MB (default: %(default)s).")
|
||||
|
||||
@@ -64,15 +64,6 @@
|
||||
line-height: 1.428571429 !important;
|
||||
}
|
||||
|
||||
.message-body li {
|
||||
margin-top: 0.5em !important;
|
||||
margin-bottom: 0.5em !important;
|
||||
}
|
||||
|
||||
.message-body li > p {
|
||||
display: inline !important;
|
||||
}
|
||||
|
||||
.dark .message-body p em {
|
||||
color: rgb(138, 138, 138) !important;
|
||||
}
|
||||
|
||||
@@ -1,65 +0,0 @@
|
||||
.chat {
|
||||
margin-left: auto;
|
||||
margin-right: auto;
|
||||
max-width: 800px;
|
||||
height: 66.67vh;
|
||||
overflow-y: auto;
|
||||
padding-right: 20px;
|
||||
display: flex;
|
||||
flex-direction: column-reverse;
|
||||
}
|
||||
|
||||
.message {
|
||||
display: grid;
|
||||
grid-template-columns: 60px 1fr;
|
||||
padding-bottom: 25px;
|
||||
font-size: 15px;
|
||||
font-family: Helvetica, Arial, sans-serif;
|
||||
line-height: 1.428571429;
|
||||
}
|
||||
|
||||
.username {
|
||||
display: none;
|
||||
}
|
||||
|
||||
.message-body {}
|
||||
|
||||
.message-body p {
|
||||
margin-bottom: 0 !important;
|
||||
font-size: 15px !important;
|
||||
line-height: 1.428571429 !important;
|
||||
}
|
||||
|
||||
.message-body li {
|
||||
margin-top: 0.5em !important;
|
||||
margin-bottom: 0.5em !important;
|
||||
}
|
||||
|
||||
.message-body li > p {
|
||||
display: inline !important;
|
||||
}
|
||||
|
||||
.dark .message-body p em {
|
||||
color: rgb(138, 138, 138) !important;
|
||||
}
|
||||
|
||||
.message-body p em {
|
||||
color: rgb(110, 110, 110) !important;
|
||||
}
|
||||
|
||||
.gradio-container .chat .assistant-message {
|
||||
padding: 15px;
|
||||
border-radius: 20px;
|
||||
background-color: #0000000f;
|
||||
margin-bottom: 17.5px;
|
||||
}
|
||||
|
||||
.gradio-container .chat .user-message {
|
||||
padding: 15px;
|
||||
border-radius: 20px;
|
||||
margin-bottom: 17.5px !important;
|
||||
}
|
||||
|
||||
.dark .chat .assistant-message {
|
||||
background-color: #ffffff21;
|
||||
}
|
||||
@@ -41,7 +41,7 @@ ol li p, ul li p {
|
||||
display: inline-block;
|
||||
}
|
||||
|
||||
#main, #parameters, #chat-settings, #interface-mode, #lora, #training-tab, #model-tab {
|
||||
#main, #parameters, #chat-settings, #interface-mode, #lora, #training-tab {
|
||||
border: 0;
|
||||
}
|
||||
|
||||
@@ -63,7 +63,3 @@ span.math.inline {
|
||||
font-size: 27px;
|
||||
vertical-align: baseline !important;
|
||||
}
|
||||
|
||||
div.svelte-15lo0d8 > *, div.svelte-15lo0d8 > .form > * {
|
||||
flex-wrap: nowrap;
|
||||
}
|
||||
|
||||
@@ -1,32 +0,0 @@
|
||||
version: "3.3"
|
||||
services:
|
||||
text-generation-webui:
|
||||
build:
|
||||
context: .
|
||||
args:
|
||||
# specify which cuda version your card supports: https://developer.nvidia.com/cuda-gpus
|
||||
TORCH_CUDA_ARCH_LIST: ${TORCH_CUDA_ARCH_LIST}
|
||||
GPTQ_VERSION: ${GPTQ_VERSION}
|
||||
WEBUI_VERSION: ${WEBUI_VERSION}
|
||||
env_file: .env
|
||||
ports:
|
||||
- "${HOST_PORT}:${CONTAINER_PORT}"
|
||||
- "${HOST_API_PORT}:${CONTAINER_API_PORT}"
|
||||
stdin_open: true
|
||||
tty: true
|
||||
volumes:
|
||||
- ./characters:/app/characters
|
||||
- ./extensions:/app/extensions
|
||||
- ./loras:/app/loras
|
||||
- ./models:/app/models
|
||||
- ./presets:/app/presets
|
||||
- ./prompts:/app/prompts
|
||||
- ./softprompts:/app/softprompts
|
||||
- ./training:/app/training
|
||||
deploy:
|
||||
resources:
|
||||
reservations:
|
||||
devices:
|
||||
- driver: nvidia
|
||||
device_ids: ['0']
|
||||
capabilities: [gpu]
|
||||
@@ -29,7 +29,6 @@ parser.add_argument('--clean', action='store_true', help='Does not resume the pr
|
||||
parser.add_argument('--check', action='store_true', help='Validates the checksums of model files.')
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
def get_file(url, output_folder):
|
||||
filename = Path(url.rsplit('/', 1)[1])
|
||||
output_path = output_folder / filename
|
||||
@@ -55,7 +54,6 @@ def get_file(url, output_folder):
|
||||
t.update(len(data))
|
||||
f.write(data)
|
||||
|
||||
|
||||
def sanitize_branch_name(branch_name):
|
||||
pattern = re.compile(r"^[a-zA-Z0-9._-]+$")
|
||||
if pattern.match(branch_name):
|
||||
@@ -63,7 +61,6 @@ def sanitize_branch_name(branch_name):
|
||||
else:
|
||||
raise ValueError("Invalid branch name. Only alphanumeric characters, period, underscore and dash are allowed.")
|
||||
|
||||
|
||||
def select_model_from_default_options():
|
||||
models = {
|
||||
"OPT 6.7B": ("facebook", "opt-6.7b", "main"),
|
||||
@@ -81,11 +78,11 @@ def select_model_from_default_options():
|
||||
choices = {}
|
||||
|
||||
print("Select the model that you want to download:\n")
|
||||
for i, name in enumerate(models):
|
||||
char = chr(ord('A') + i)
|
||||
for i,name in enumerate(models):
|
||||
char = chr(ord('A')+i)
|
||||
choices[char] = name
|
||||
print(f"{char}) {name}")
|
||||
char = chr(ord('A') + len(models))
|
||||
char = chr(ord('A')+len(models))
|
||||
print(f"{char}) None of the above")
|
||||
|
||||
print()
|
||||
@@ -109,7 +106,6 @@ EleutherAI/pythia-1.4b-deduped
|
||||
|
||||
return model, branch
|
||||
|
||||
|
||||
def get_download_links_from_huggingface(model, branch):
|
||||
base = "https://huggingface.co"
|
||||
page = f"/api/models/{model}/tree/{branch}?cursor="
|
||||
@@ -170,17 +166,15 @@ def get_download_links_from_huggingface(model, branch):
|
||||
|
||||
# If both pytorch and safetensors are available, download safetensors only
|
||||
if (has_pytorch or has_pt) and has_safetensors:
|
||||
for i in range(len(classifications) - 1, -1, -1):
|
||||
for i in range(len(classifications)-1, -1, -1):
|
||||
if classifications[i] in ['pytorch', 'pt']:
|
||||
links.pop(i)
|
||||
|
||||
return links, sha256, is_lora
|
||||
|
||||
|
||||
def download_files(file_list, output_folder, num_threads=8):
|
||||
thread_map(lambda url: get_file(url, output_folder), file_list, max_workers=num_threads, disable=True)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
model = args.MODEL
|
||||
branch = args.branch
|
||||
|
||||
@@ -9,7 +9,6 @@ params = {
|
||||
'port': 5000,
|
||||
}
|
||||
|
||||
|
||||
class Handler(BaseHTTPRequestHandler):
|
||||
def do_GET(self):
|
||||
if self.path == '/api/v1/model':
|
||||
@@ -33,7 +32,7 @@ class Handler(BaseHTTPRequestHandler):
|
||||
self.end_headers()
|
||||
|
||||
prompt = body['prompt']
|
||||
prompt_lines = [k.strip() for k in prompt.split('\n')]
|
||||
prompt_lines = [l.strip() for l in prompt.split('\n')]
|
||||
|
||||
max_context = body.get('max_context_length', 2048)
|
||||
|
||||
@@ -41,27 +40,24 @@ class Handler(BaseHTTPRequestHandler):
|
||||
prompt_lines.pop(0)
|
||||
|
||||
prompt = '\n'.join(prompt_lines)
|
||||
generate_params = {
|
||||
'max_new_tokens': int(body.get('max_length', 200)),
|
||||
'do_sample': bool(body.get('do_sample', True)),
|
||||
'temperature': float(body.get('temperature', 0.5)),
|
||||
'top_p': float(body.get('top_p', 1)),
|
||||
'typical_p': float(body.get('typical', 1)),
|
||||
'repetition_penalty': float(body.get('rep_pen', 1.1)),
|
||||
'encoder_repetition_penalty': 1,
|
||||
'top_k': int(body.get('top_k', 0)),
|
||||
'min_length': int(body.get('min_length', 0)),
|
||||
'no_repeat_ngram_size': int(body.get('no_repeat_ngram_size', 0)),
|
||||
'num_beams': int(body.get('num_beams', 1)),
|
||||
'penalty_alpha': float(body.get('penalty_alpha', 0)),
|
||||
'length_penalty': float(body.get('length_penalty', 1)),
|
||||
'early_stopping': bool(body.get('early_stopping', False)),
|
||||
'seed': int(body.get('seed', -1)),
|
||||
}
|
||||
|
||||
generator = generate_reply(
|
||||
prompt,
|
||||
generate_params,
|
||||
question = prompt,
|
||||
max_new_tokens = int(body.get('max_length', 200)),
|
||||
do_sample=bool(body.get('do_sample', True)),
|
||||
temperature=float(body.get('temperature', 0.5)),
|
||||
top_p=float(body.get('top_p', 1)),
|
||||
typical_p=float(body.get('typical', 1)),
|
||||
repetition_penalty=float(body.get('rep_pen', 1.1)),
|
||||
encoder_repetition_penalty=1,
|
||||
top_k=int(body.get('top_k', 0)),
|
||||
min_length=int(body.get('min_length', 0)),
|
||||
no_repeat_ngram_size=int(body.get('no_repeat_ngram_size',0)),
|
||||
num_beams=int(body.get('num_beams',1)),
|
||||
penalty_alpha=float(body.get('penalty_alpha', 0)),
|
||||
length_penalty=float(body.get('length_penalty', 1)),
|
||||
early_stopping=bool(body.get('early_stopping', False)),
|
||||
seed=int(body.get('seed', -1)),
|
||||
stopping_strings=body.get('stopping_strings', []),
|
||||
)
|
||||
|
||||
@@ -96,6 +92,5 @@ def run_server():
|
||||
print(f'Starting KoboldAI compatible api at http://{server_addr[0]}:{server_addr[1]}/api')
|
||||
server.serve_forever()
|
||||
|
||||
|
||||
def setup():
|
||||
Thread(target=run_server, daemon=True).start()
|
||||
|
||||
@@ -5,7 +5,6 @@ params = {
|
||||
"bias string": " *I am so happy*",
|
||||
}
|
||||
|
||||
|
||||
def input_modifier(string):
|
||||
"""
|
||||
This function is applied to your text inputs before
|
||||
@@ -14,7 +13,6 @@ def input_modifier(string):
|
||||
|
||||
return string
|
||||
|
||||
|
||||
def output_modifier(string):
|
||||
"""
|
||||
This function is applied to the model outputs.
|
||||
@@ -22,7 +20,6 @@ def output_modifier(string):
|
||||
|
||||
return string
|
||||
|
||||
|
||||
def bot_prefix_modifier(string):
|
||||
"""
|
||||
This function is only applied in chat mode. It modifies
|
||||
@@ -30,12 +27,11 @@ def bot_prefix_modifier(string):
|
||||
behavior.
|
||||
"""
|
||||
|
||||
if params['activate']:
|
||||
if params['activate'] == True:
|
||||
return f'{string} {params["bias string"].strip()} '
|
||||
else:
|
||||
return string
|
||||
|
||||
|
||||
def ui():
|
||||
# Gradio elements
|
||||
activate = gr.Checkbox(value=params['activate'], label='Activate character bias')
|
||||
|
||||
@@ -22,8 +22,6 @@ if not shared.args.no_stream:
|
||||
raise ValueError
|
||||
|
||||
# Check if the API is valid and refresh the UI accordingly.
|
||||
|
||||
|
||||
def check_valid_api():
|
||||
|
||||
global user, user_info, params
|
||||
@@ -31,7 +29,7 @@ def check_valid_api():
|
||||
user = ElevenLabsUser(params['api_key'])
|
||||
user_info = user._get_subscription_data()
|
||||
print('checking api')
|
||||
if not params['activate']:
|
||||
if params['activate'] == False:
|
||||
return gr.update(value='Disconnected')
|
||||
elif user_info is None:
|
||||
print('Incorrect API Key')
|
||||
@@ -41,8 +39,6 @@ def check_valid_api():
|
||||
return gr.update(value='Connected')
|
||||
|
||||
# Once the API is verified, get the available voices and update the dropdown list
|
||||
|
||||
|
||||
def refresh_voices():
|
||||
|
||||
global user, user_info
|
||||
@@ -55,12 +51,10 @@ def refresh_voices():
|
||||
else:
|
||||
return
|
||||
|
||||
|
||||
def remove_surrounded_chars(string):
|
||||
# this expression matches to 'as few symbols as possible (0 upwards) between any asterisks' OR
|
||||
# 'as few symbols as possible (0 upwards) between an asterisk and the end of the string'
|
||||
return re.sub('\*[^\*]*?(\*|$)', '', string)
|
||||
|
||||
return re.sub('\*[^\*]*?(\*|$)','',string)
|
||||
|
||||
def input_modifier(string):
|
||||
"""
|
||||
@@ -70,7 +64,6 @@ def input_modifier(string):
|
||||
|
||||
return string
|
||||
|
||||
|
||||
def output_modifier(string):
|
||||
"""
|
||||
This function is applied to the model outputs.
|
||||
@@ -78,9 +71,9 @@ def output_modifier(string):
|
||||
|
||||
global params, wav_idx, user, user_info
|
||||
|
||||
if not params['activate']:
|
||||
if params['activate'] == False:
|
||||
return string
|
||||
elif user_info is None:
|
||||
elif user_info == None:
|
||||
return string
|
||||
|
||||
string = remove_surrounded_chars(string)
|
||||
@@ -101,7 +94,6 @@ def output_modifier(string):
|
||||
wav_idx += 1
|
||||
return string
|
||||
|
||||
|
||||
def ui():
|
||||
|
||||
# Gradio elements
|
||||
|
||||
@@ -2,8 +2,9 @@ from pathlib import Path
|
||||
|
||||
import gradio as gr
|
||||
|
||||
from modules.chat import load_character
|
||||
from modules.html_generator import get_image_cache
|
||||
from modules.shared import gradio
|
||||
from modules.shared import gradio, settings
|
||||
|
||||
|
||||
def generate_css():
|
||||
@@ -63,13 +64,22 @@ def generate_html():
|
||||
for file in sorted(Path("characters").glob("*")):
|
||||
if file.suffix in [".json", ".yml", ".yaml"]:
|
||||
character = file.stem
|
||||
container_html = '<div class="character-container">'
|
||||
container_html = f'<div class="character-container">'
|
||||
image_html = "<div class='placeholder'></div>"
|
||||
|
||||
for path in [Path(f"characters/{character}.{extension}") for extension in ['png', 'jpg', 'jpeg']]:
|
||||
for i in [
|
||||
f"characters/{character}.png",
|
||||
f"characters/{character}.jpg",
|
||||
f"characters/{character}.jpeg",
|
||||
]:
|
||||
|
||||
path = Path(i)
|
||||
if path.exists():
|
||||
try:
|
||||
image_html = f'<img src="file/{get_image_cache(path)}">'
|
||||
break
|
||||
except:
|
||||
continue
|
||||
|
||||
container_html += f'{image_html} <span class="character-name">{character}</span>'
|
||||
container_html += "</div>"
|
||||
@@ -85,7 +95,7 @@ def select_character(evt: gr.SelectData):
|
||||
def ui():
|
||||
with gr.Accordion("Character gallery", open=False):
|
||||
update = gr.Button("Refresh")
|
||||
gr.HTML(value="<style>" + generate_css() + "</style>")
|
||||
gr.HTML(value="<style>"+generate_css()+"</style>")
|
||||
gallery = gr.Dataset(components=[gr.HTML(visible=False)],
|
||||
label="",
|
||||
samples=generate_html(),
|
||||
|
||||
@@ -7,7 +7,6 @@ params = {
|
||||
|
||||
language_codes = {'Afrikaans': 'af', 'Albanian': 'sq', 'Amharic': 'am', 'Arabic': 'ar', 'Armenian': 'hy', 'Azerbaijani': 'az', 'Basque': 'eu', 'Belarusian': 'be', 'Bengali': 'bn', 'Bosnian': 'bs', 'Bulgarian': 'bg', 'Catalan': 'ca', 'Cebuano': 'ceb', 'Chinese (Simplified)': 'zh-CN', 'Chinese (Traditional)': 'zh-TW', 'Corsican': 'co', 'Croatian': 'hr', 'Czech': 'cs', 'Danish': 'da', 'Dutch': 'nl', 'English': 'en', 'Esperanto': 'eo', 'Estonian': 'et', 'Finnish': 'fi', 'French': 'fr', 'Frisian': 'fy', 'Galician': 'gl', 'Georgian': 'ka', 'German': 'de', 'Greek': 'el', 'Gujarati': 'gu', 'Haitian Creole': 'ht', 'Hausa': 'ha', 'Hawaiian': 'haw', 'Hebrew': 'iw', 'Hindi': 'hi', 'Hmong': 'hmn', 'Hungarian': 'hu', 'Icelandic': 'is', 'Igbo': 'ig', 'Indonesian': 'id', 'Irish': 'ga', 'Italian': 'it', 'Japanese': 'ja', 'Javanese': 'jw', 'Kannada': 'kn', 'Kazakh': 'kk', 'Khmer': 'km', 'Korean': 'ko', 'Kurdish': 'ku', 'Kyrgyz': 'ky', 'Lao': 'lo', 'Latin': 'la', 'Latvian': 'lv', 'Lithuanian': 'lt', 'Luxembourgish': 'lb', 'Macedonian': 'mk', 'Malagasy': 'mg', 'Malay': 'ms', 'Malayalam': 'ml', 'Maltese': 'mt', 'Maori': 'mi', 'Marathi': 'mr', 'Mongolian': 'mn', 'Myanmar (Burmese)': 'my', 'Nepali': 'ne', 'Norwegian': 'no', 'Nyanja (Chichewa)': 'ny', 'Pashto': 'ps', 'Persian': 'fa', 'Polish': 'pl', 'Portuguese (Portugal, Brazil)': 'pt', 'Punjabi': 'pa', 'Romanian': 'ro', 'Russian': 'ru', 'Samoan': 'sm', 'Scots Gaelic': 'gd', 'Serbian': 'sr', 'Sesotho': 'st', 'Shona': 'sn', 'Sindhi': 'sd', 'Sinhala (Sinhalese)': 'si', 'Slovak': 'sk', 'Slovenian': 'sl', 'Somali': 'so', 'Spanish': 'es', 'Sundanese': 'su', 'Swahili': 'sw', 'Swedish': 'sv', 'Tagalog (Filipino)': 'tl', 'Tajik': 'tg', 'Tamil': 'ta', 'Telugu': 'te', 'Thai': 'th', 'Turkish': 'tr', 'Ukrainian': 'uk', 'Urdu': 'ur', 'Uzbek': 'uz', 'Vietnamese': 'vi', 'Welsh': 'cy', 'Xhosa': 'xh', 'Yiddish': 'yi', 'Yoruba': 'yo', 'Zulu': 'zu'}
|
||||
|
||||
|
||||
def input_modifier(string):
|
||||
"""
|
||||
This function is applied to your text inputs before
|
||||
@@ -16,7 +15,6 @@ def input_modifier(string):
|
||||
|
||||
return GoogleTranslator(source=params['language string'], target='en').translate(string)
|
||||
|
||||
|
||||
def output_modifier(string):
|
||||
"""
|
||||
This function is applied to the model outputs.
|
||||
@@ -24,7 +22,6 @@ def output_modifier(string):
|
||||
|
||||
return GoogleTranslator(source='en', target=params['language string']).translate(string)
|
||||
|
||||
|
||||
def bot_prefix_modifier(string):
|
||||
"""
|
||||
This function is only applied in chat mode. It modifies
|
||||
@@ -34,7 +31,6 @@ def bot_prefix_modifier(string):
|
||||
|
||||
return string
|
||||
|
||||
|
||||
def ui():
|
||||
# Finding the language name from the language code to use as the default value
|
||||
language_name = list(language_codes.keys())[list(language_codes.values()).index(params['language string'])]
|
||||
|
||||
@@ -4,14 +4,12 @@ import pandas as pd
|
||||
|
||||
df = pd.read_csv("https://raw.githubusercontent.com/devbrones/llama-prompts/main/prompts/prompts.csv")
|
||||
|
||||
|
||||
def get_prompt_by_name(name):
|
||||
if name == 'None':
|
||||
return ''
|
||||
else:
|
||||
return df[df['Prompt name'] == name].iloc[0]['Prompt'].replace('\\n', '\n')
|
||||
|
||||
|
||||
def ui():
|
||||
if not shared.is_chat():
|
||||
choices = ['None'] + list(df['Prompt name'])
|
||||
|
||||
51
extensions/prompt_template/script.py
Normal file
51
extensions/prompt_template/script.py
Normal file
@@ -0,0 +1,51 @@
|
||||
from pathlib import Path
|
||||
|
||||
import gradio as gr
|
||||
|
||||
from modules import shared
|
||||
from modules import ui as _ui
|
||||
|
||||
params = {
|
||||
'template': '%input%'
|
||||
}
|
||||
|
||||
def get_available_templates():
|
||||
return ['None'] + sorted(set((k.stem for k in Path('extensions/prompt_template/templates').glob('*.txt'))), key=str.lower)
|
||||
|
||||
def load_template(fname):
|
||||
if fname in ['None', '']:
|
||||
return '%input%'
|
||||
else:
|
||||
with open(Path(f'extensions/prompt_template/templates/{fname}.txt'), 'r', encoding='utf-8') as f:
|
||||
text = f.read()
|
||||
if text[-1] == '\n':
|
||||
text = text[:-1]
|
||||
return text
|
||||
|
||||
def input_modifier(string):
|
||||
"""
|
||||
This function is applied to your text inputs before
|
||||
they are fed into the model.
|
||||
"""
|
||||
|
||||
return params['template'].replace('%input%', string)
|
||||
|
||||
def output_modifier(string):
|
||||
return f'\n{string}'
|
||||
|
||||
def setup():
|
||||
shared.args.verbose = True
|
||||
|
||||
def ui():
|
||||
# Gradio elements
|
||||
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
template = gr.Textbox(value=params['template'], info="%input% will be replaced with your user input.", label='Template')
|
||||
with gr.Column():
|
||||
with gr.Row():
|
||||
template_menu = gr.Dropdown(choices=get_available_templates(), value='None', label='Available templates')
|
||||
_ui.create_refresh_button(shared.gradio['model_menu'], lambda : None, lambda : {'choices': get_available_templates()}, 'refresh-button')
|
||||
|
||||
template_menu.change(load_template, template_menu, template)
|
||||
template.change(lambda x: params.update({"template": x}), template, None)
|
||||
@@ -1,6 +1,5 @@
|
||||
Below is an instruction that describes a task. Write a response that appropriately completes the request.
|
||||
### Instruction:
|
||||
Write a poem about the transformers Python library.
|
||||
Mention the word "large language models" in that poem.
|
||||
%input%
|
||||
### Response:
|
||||
|
||||
1
extensions/prompt_template/templates/Open Assistant.txt
Normal file
1
extensions/prompt_template/templates/Open Assistant.txt
Normal file
@@ -0,0 +1 @@
|
||||
<|prompter|>%input%<|endoftext|><|assistant|>
|
||||
@@ -30,15 +30,12 @@ streaming_state = shared.args.no_stream # remember if chat streaming was enable
|
||||
picture_response = False # specifies if the next model response should appear as a picture
|
||||
pic_id = 0
|
||||
|
||||
|
||||
def remove_surrounded_chars(string):
|
||||
# this expression matches to 'as few symbols as possible (0 upwards) between any asterisks' OR
|
||||
# 'as few symbols as possible (0 upwards) between an asterisk and the end of the string'
|
||||
return re.sub('\*[^\*]*?(\*|$)', '', string)
|
||||
return re.sub('\*[^\*]*?(\*|$)','',string)
|
||||
|
||||
# I don't even need input_hijack for this as visible text will be commited to history as the unmodified string
|
||||
|
||||
|
||||
def input_modifier(string):
|
||||
"""
|
||||
This function is applied to your text inputs before
|
||||
@@ -65,8 +62,6 @@ def input_modifier(string):
|
||||
return string
|
||||
|
||||
# Get and save the Stable Diffusion-generated picture
|
||||
|
||||
|
||||
def get_SD_pictures(description):
|
||||
|
||||
global params, pic_id
|
||||
@@ -88,7 +83,7 @@ def get_SD_pictures(description):
|
||||
|
||||
visible_result = ""
|
||||
for img_str in r['images']:
|
||||
image = Image.open(io.BytesIO(base64.b64decode(img_str.split(",", 1)[0])))
|
||||
image = Image.open(io.BytesIO(base64.b64decode(img_str.split(",",1)[0])))
|
||||
if params['save_img']:
|
||||
output_file = Path(f'extensions/sd_api_pictures/outputs/{pic_id:06d}.png')
|
||||
image.save(output_file.as_posix())
|
||||
@@ -106,8 +101,6 @@ def get_SD_pictures(description):
|
||||
|
||||
# TODO: how do I make the UI history ignore the resulting pictures (I don't want HTML to appear in history)
|
||||
# and replace it with 'text' for the purposes of logging?
|
||||
|
||||
|
||||
def output_modifier(string):
|
||||
"""
|
||||
This function is applied to the model outputs.
|
||||
@@ -137,7 +130,6 @@ def output_modifier(string):
|
||||
shared.args.no_stream = streaming_state
|
||||
return image + "\n" + text
|
||||
|
||||
|
||||
def bot_prefix_modifier(string):
|
||||
"""
|
||||
This function is only applied in chat mode. It modifies
|
||||
@@ -147,12 +139,10 @@ def bot_prefix_modifier(string):
|
||||
|
||||
return string
|
||||
|
||||
|
||||
def force_pic():
|
||||
global picture_response
|
||||
picture_response = True
|
||||
|
||||
|
||||
def ui():
|
||||
|
||||
# Gradio elements
|
||||
@@ -172,7 +162,7 @@ def ui():
|
||||
prompt_prefix = gr.Textbox(placeholder=params['prompt_prefix'], value=params['prompt_prefix'], label='Prompt Prefix (best used to describe the look of the character)')
|
||||
with gr.Row():
|
||||
negative_prompt = gr.Textbox(placeholder=params['negative_prompt'], value=params['negative_prompt'], label='Negative Prompt')
|
||||
dimensions = gr.Slider(256, 702, value=params['side_length'], step=64, label='Image dimensions')
|
||||
dimensions = gr.Slider(256,702,value=params['side_length'],step=64,label='Image dimensions')
|
||||
# model = gr.Dropdown(value=SD_models[0], choices=SD_models, label='Model')
|
||||
|
||||
# Event functions to update the parameters in the backend
|
||||
@@ -186,4 +176,4 @@ def ui():
|
||||
|
||||
force_btn.click(force_pic)
|
||||
generate_now_btn.click(force_pic)
|
||||
generate_now_btn.click(chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream)
|
||||
generate_now_btn.click(eval('chat.cai_chatbot_wrapper'), shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream)
|
||||
@@ -2,11 +2,12 @@ import base64
|
||||
from io import BytesIO
|
||||
|
||||
import gradio as gr
|
||||
import modules.chat as chat
|
||||
import modules.shared as shared
|
||||
import torch
|
||||
from PIL import Image
|
||||
from transformers import BlipForConditionalGeneration, BlipProcessor
|
||||
|
||||
from modules import chat, shared
|
||||
|
||||
# If 'state' is True, will hijack the next chat generation with
|
||||
# custom input text given by 'value' in the format [text, visible_text]
|
||||
input_hijack = {
|
||||
@@ -17,13 +18,11 @@ input_hijack = {
|
||||
processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
|
||||
model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base", torch_dtype=torch.float32).to("cpu")
|
||||
|
||||
|
||||
def caption_image(raw_image):
|
||||
inputs = processor(raw_image.convert('RGB'), return_tensors="pt").to("cpu", torch.float32)
|
||||
out = model.generate(**inputs, max_new_tokens=100)
|
||||
return processor.decode(out[0], skip_special_tokens=True)
|
||||
|
||||
|
||||
def generate_chat_picture(picture, name1, name2):
|
||||
text = f'*{name1} sends {name2} a picture that contains the following: "{caption_image(picture)}"*'
|
||||
# lower the resolution of sent images for the chat, otherwise the log size gets out of control quickly with all the base64 values in visible history
|
||||
@@ -34,15 +33,16 @@ def generate_chat_picture(picture, name1, name2):
|
||||
visible_text = f'<img src="data:image/jpeg;base64,{img_str}" alt="{text}">'
|
||||
return text, visible_text
|
||||
|
||||
|
||||
def ui():
|
||||
picture_select = gr.Image(label='Send a picture', type='pil')
|
||||
|
||||
function_call = 'chat.cai_chatbot_wrapper' if shared.args.cai_chat else 'chat.chatbot_wrapper'
|
||||
|
||||
# Prepare the hijack with custom inputs
|
||||
picture_select.upload(lambda picture, name1, name2: input_hijack.update({"state": True, "value": generate_chat_picture(picture, name1, name2)}), [picture_select, shared.gradio['name1'], shared.gradio['name2']], None)
|
||||
|
||||
# Call the generation function
|
||||
picture_select.upload(chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream)
|
||||
picture_select.upload(eval(function_call), shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream)
|
||||
|
||||
# Clear the picture from the upload field
|
||||
picture_select.upload(lambda: None, [], [picture_select], show_progress=False)
|
||||
picture_select.upload(lambda : None, [], [picture_select], show_progress=False)
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import inspect
|
||||
import re
|
||||
import sys
|
||||
from pathlib import Path
|
||||
@@ -17,11 +16,9 @@ from quant import make_quant
|
||||
|
||||
|
||||
def _load_quant(model, checkpoint, wbits, groupsize=-1, faster_kernel=False, exclude_layers=['lm_head'], kernel_switch_threshold=128):
|
||||
|
||||
config = AutoConfig.from_pretrained(model)
|
||||
def noop(*args, **kwargs):
|
||||
pass
|
||||
|
||||
config = AutoConfig.from_pretrained(model)
|
||||
torch.nn.init.kaiming_uniform_ = noop
|
||||
torch.nn.init.uniform_ = noop
|
||||
torch.nn.init.normal_ = noop
|
||||
@@ -36,42 +33,26 @@ def _load_quant(model, checkpoint, wbits, groupsize=-1, faster_kernel=False, exc
|
||||
for name in exclude_layers:
|
||||
if name in layers:
|
||||
del layers[name]
|
||||
|
||||
gptq_args = inspect.getfullargspec(make_quant).args
|
||||
|
||||
make_quant_kwargs = {
|
||||
'module': model,
|
||||
'names': layers,
|
||||
'bits': wbits,
|
||||
}
|
||||
if 'groupsize' in gptq_args:
|
||||
make_quant_kwargs['groupsize'] = groupsize
|
||||
if 'faster' in gptq_args:
|
||||
make_quant_kwargs['faster'] = faster_kernel
|
||||
if 'kernel_switch_threshold' in gptq_args:
|
||||
make_quant_kwargs['kernel_switch_threshold'] = kernel_switch_threshold
|
||||
|
||||
make_quant(**make_quant_kwargs)
|
||||
make_quant(model, layers, wbits, groupsize, faster=faster_kernel, kernel_switch_threshold=kernel_switch_threshold)
|
||||
|
||||
del layers
|
||||
|
||||
print('Loading model ...')
|
||||
if checkpoint.endswith('.safetensors'):
|
||||
from safetensors.torch import load_file as safe_load
|
||||
model.load_state_dict(safe_load(checkpoint), strict=False)
|
||||
model.load_state_dict(safe_load(checkpoint))
|
||||
else:
|
||||
model.load_state_dict(torch.load(checkpoint), strict=False)
|
||||
model.load_state_dict(torch.load(checkpoint))
|
||||
model.seqlen = 2048
|
||||
print('Done.')
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def load_quantized(model_name):
|
||||
if not shared.args.model_type:
|
||||
# Try to determine model type from model name
|
||||
name = model_name.lower()
|
||||
if any((k in name for k in ['llama', 'alpaca', 'vicuna'])):
|
||||
if any((k in name for k in ['llama', 'alpaca'])):
|
||||
model_type = 'llama'
|
||||
elif any((k in name for k in ['opt-', 'galactica'])):
|
||||
model_type = 'opt'
|
||||
@@ -84,18 +65,16 @@ def load_quantized(model_name):
|
||||
else:
|
||||
model_type = shared.args.model_type.lower()
|
||||
|
||||
if shared.args.pre_layer and model_type == 'llama':
|
||||
if model_type == 'llama' and shared.args.pre_layer:
|
||||
load_quant = llama_inference_offload.load_quant
|
||||
elif model_type in ('llama', 'opt', 'gptj'):
|
||||
if shared.args.pre_layer:
|
||||
print("Warning: ignoring --pre_layer because it only works for llama model type.")
|
||||
load_quant = _load_quant
|
||||
else:
|
||||
print("Unknown pre-quantized model type specified. Only 'llama', 'opt' and 'gptj' are supported")
|
||||
exit()
|
||||
|
||||
# Now we are going to try to locate the quantized model file.
|
||||
path_to_model = Path(f'{shared.args.model_dir}/{model_name}')
|
||||
path_to_model = Path(f'models/{model_name}')
|
||||
found_pts = list(path_to_model.glob("*.pt"))
|
||||
found_safetensors = list(path_to_model.glob("*.safetensors"))
|
||||
pt_path = None
|
||||
@@ -116,8 +95,8 @@ def load_quantized(model_name):
|
||||
else:
|
||||
pt_model = f'{model_name}-{shared.args.wbits}bit'
|
||||
|
||||
# Try to find the .safetensors or .pt both in the model dir and in the subfolder
|
||||
for path in [Path(p + ext) for ext in ['.safetensors', '.pt'] for p in [f"{shared.args.model_dir}/{pt_model}", f"{path_to_model}/{pt_model}"]]:
|
||||
# Try to find the .safetensors or .pt both in models/ and in the subfolder
|
||||
for path in [Path(p+ext) for ext in ['.safetensors', '.pt'] for p in [f"models/{pt_model}", f"{path_to_model}/{pt_model}"]]:
|
||||
if path.exists():
|
||||
print(f"Found {path}")
|
||||
pt_path = path
|
||||
@@ -128,7 +107,7 @@ def load_quantized(model_name):
|
||||
exit()
|
||||
|
||||
# qwopqwop200's offload
|
||||
if model_type == 'llama' and shared.args.pre_layer:
|
||||
if shared.args.pre_layer:
|
||||
model = load_quant(str(path_to_model), str(pt_path), shared.args.wbits, shared.args.groupsize, shared.args.pre_layer)
|
||||
else:
|
||||
threshold = False if model_type == 'gptj' else 128
|
||||
@@ -136,7 +115,7 @@ def load_quantized(model_name):
|
||||
|
||||
# accelerate offload (doesn't work properly)
|
||||
if shared.args.gpu_memory:
|
||||
memory_map = list(map(lambda x: x.strip(), shared.args.gpu_memory))
|
||||
memory_map = list(map(lambda x : x.strip(), shared.args.gpu_memory))
|
||||
max_cpu_memory = shared.args.cpu_memory.strip() if shared.args.cpu_memory is not None else '99GiB'
|
||||
max_memory = {}
|
||||
for i in range(len(memory_map)):
|
||||
|
||||
@@ -13,7 +13,6 @@ def reload_model():
|
||||
clear_torch_cache()
|
||||
shared.model, shared.tokenizer = load_model(shared.model_name)
|
||||
|
||||
|
||||
def add_lora_to_model(lora_name):
|
||||
|
||||
# If a LoRA had been previously loaded, or if we want
|
||||
@@ -28,7 +27,7 @@ def add_lora_to_model(lora_name):
|
||||
if not shared.args.cpu:
|
||||
params['dtype'] = shared.model.dtype
|
||||
if hasattr(shared.model, "hf_device_map"):
|
||||
params['device_map'] = {"base_model.model." + k: v for k, v in shared.model.hf_device_map.items()}
|
||||
params['device_map'] = {"base_model.model."+k: v for k, v in shared.model.hf_device_map.items()}
|
||||
elif shared.args.load_in_8bit:
|
||||
params['device_map'] = {'': 0}
|
||||
|
||||
|
||||
@@ -36,13 +36,13 @@ class RWKVModel:
|
||||
|
||||
def generate(self, context="", token_count=20, temperature=1, top_p=1, top_k=50, repetition_penalty=None, alpha_frequency=0.1, alpha_presence=0.1, token_ban=[0], token_stop=[], callback=None):
|
||||
args = PIPELINE_ARGS(
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
top_k=top_k,
|
||||
alpha_frequency=alpha_frequency, # Frequency Penalty (as in GPT-3)
|
||||
alpha_presence=alpha_presence, # Presence Penalty (as in GPT-3)
|
||||
token_ban=token_ban, # ban the generation of some tokens
|
||||
token_stop=token_stop
|
||||
temperature = temperature,
|
||||
top_p = top_p,
|
||||
top_k = top_k,
|
||||
alpha_frequency = alpha_frequency, # Frequency Penalty (as in GPT-3)
|
||||
alpha_presence = alpha_presence, # Presence Penalty (as in GPT-3)
|
||||
token_ban = token_ban, # ban the generation of some tokens
|
||||
token_stop = token_stop
|
||||
)
|
||||
|
||||
return self.pipeline.generate(context, token_count=token_count, args=args, callback=callback)
|
||||
@@ -54,7 +54,6 @@ class RWKVModel:
|
||||
reply += token
|
||||
yield reply
|
||||
|
||||
|
||||
class RWKVTokenizer:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@@ -1,39 +0,0 @@
|
||||
import json
|
||||
|
||||
import gradio as gr
|
||||
|
||||
from modules import shared
|
||||
from modules.text_generation import generate_reply
|
||||
|
||||
|
||||
def generate_reply_wrapper(string):
|
||||
generate_params = {
|
||||
'do_sample': True,
|
||||
'temperature': 1,
|
||||
'top_p': 1,
|
||||
'typical_p': 1,
|
||||
'repetition_penalty': 1,
|
||||
'encoder_repetition_penalty': 1,
|
||||
'top_k': 50,
|
||||
'num_beams': 1,
|
||||
'penalty_alpha': 0,
|
||||
'min_length': 0,
|
||||
'length_penalty': 1,
|
||||
'no_repeat_ngram_size': 0,
|
||||
'early_stopping': False,
|
||||
}
|
||||
params = json.loads(string)
|
||||
for k in params[1]:
|
||||
generate_params[k] = params[1][k]
|
||||
for i in generate_reply(params[0], generate_params):
|
||||
yield i
|
||||
|
||||
|
||||
def create_apis():
|
||||
t1 = gr.Textbox(visible=False)
|
||||
t2 = gr.Textbox(visible=False)
|
||||
dummy = gr.Button(visible=False)
|
||||
|
||||
input_params = [t1]
|
||||
output_params = [t2] + [shared.gradio[k] for k in ['markdown', 'html']]
|
||||
dummy.click(generate_reply_wrapper, input_params, output_params, api_name='textgen')
|
||||
@@ -30,7 +30,6 @@ class _SentinelTokenStoppingCriteria(transformers.StoppingCriteria):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
class Stream(transformers.StoppingCriteria):
|
||||
def __init__(self, callback_func=None):
|
||||
self.callback_func = callback_func
|
||||
@@ -40,7 +39,6 @@ class Stream(transformers.StoppingCriteria):
|
||||
self.callback_func(input_ids[0])
|
||||
return False
|
||||
|
||||
|
||||
class Iteratorize:
|
||||
|
||||
"""
|
||||
@@ -49,8 +47,8 @@ class Iteratorize:
|
||||
"""
|
||||
|
||||
def __init__(self, func, kwargs={}, callback=None):
|
||||
self.mfunc = func
|
||||
self.c_callback = callback
|
||||
self.mfunc=func
|
||||
self.c_callback=callback
|
||||
self.q = Queue()
|
||||
self.sentinel = object()
|
||||
self.kwargs = kwargs
|
||||
@@ -82,7 +80,7 @@ class Iteratorize:
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
obj = self.q.get(True, None)
|
||||
obj = self.q.get(True,None)
|
||||
if obj is self.sentinel:
|
||||
raise StopIteration
|
||||
else:
|
||||
@@ -98,7 +96,6 @@ class Iteratorize:
|
||||
self.stop_now = True
|
||||
clear_torch_cache()
|
||||
|
||||
|
||||
def clear_torch_cache():
|
||||
gc.collect()
|
||||
if not shared.args.cpu:
|
||||
|
||||
240
modules/chat.py
240
modules/chat.py
@@ -12,54 +12,45 @@ from PIL import Image
|
||||
import modules.extensions as extensions_module
|
||||
import modules.shared as shared
|
||||
from modules.extensions import apply_extensions
|
||||
from modules.html_generator import (fix_newlines, chat_html_wrapper,
|
||||
make_thumbnail)
|
||||
from modules.html_generator import fix_newlines, generate_chat_html
|
||||
from modules.text_generation import (encode, generate_reply,
|
||||
get_max_prompt_length)
|
||||
|
||||
|
||||
def generate_chat_prompt(user_input, max_new_tokens, name1, name2, context, chat_prompt_size, **kwargs):
|
||||
is_instruct = kwargs['is_instruct'] if 'is_instruct' in kwargs else False
|
||||
end_of_turn = kwargs['end_of_turn'] if 'end_of_turn' in kwargs else ''
|
||||
impersonate = kwargs['impersonate'] if 'impersonate' in kwargs else False
|
||||
also_return_rows = kwargs['also_return_rows'] if 'also_return_rows' in kwargs else False
|
||||
def generate_chat_output(history, name1, name2, character):
|
||||
if shared.args.cai_chat:
|
||||
return generate_chat_html(history, name1, name2, character)
|
||||
else:
|
||||
return history
|
||||
|
||||
def generate_chat_prompt(user_input, max_new_tokens, name1, name2, context, chat_prompt_size, impersonate=False, also_return_rows=False):
|
||||
user_input = fix_newlines(user_input)
|
||||
rows = [f"{context.strip()}\n"]
|
||||
|
||||
# Finding the maximum prompt size
|
||||
if shared.soft_prompt:
|
||||
chat_prompt_size -= shared.soft_prompt_tensor.shape[1]
|
||||
max_length = min(get_max_prompt_length(max_new_tokens), chat_prompt_size)
|
||||
|
||||
if is_instruct:
|
||||
prefix1 = f"{name1}\n"
|
||||
prefix2 = f"{name2}\n"
|
||||
else:
|
||||
prefix1 = f"{name1}: "
|
||||
prefix2 = f"{name2}: "
|
||||
|
||||
i = len(shared.history['internal']) - 1
|
||||
i = len(shared.history['internal'])-1
|
||||
while i >= 0 and len(encode(''.join(rows), max_new_tokens)[0]) < max_length:
|
||||
rows.insert(1, f"{prefix2}{shared.history['internal'][i][1].strip()}{end_of_turn}\n")
|
||||
string = shared.history['internal'][i][0]
|
||||
if string not in ['', '<|BEGIN-VISIBLE-CHAT|>']:
|
||||
rows.insert(1, f"{prefix1}{string.strip()}{end_of_turn}\n")
|
||||
rows.insert(1, f"{name2}: {shared.history['internal'][i][1].strip()}\n")
|
||||
prev_user_input = shared.history['internal'][i][0]
|
||||
if prev_user_input not in ['', '<|BEGIN-VISIBLE-CHAT|>']:
|
||||
rows.insert(1, f"{name1}: {prev_user_input.strip()}\n")
|
||||
i -= 1
|
||||
|
||||
if impersonate:
|
||||
rows.append(f"{prefix1.strip() if not is_instruct else prefix1}")
|
||||
limit = 2
|
||||
else:
|
||||
# Adding the user message
|
||||
user_input = fix_newlines(user_input)
|
||||
if not impersonate:
|
||||
if len(user_input) > 0:
|
||||
rows.append(f"{prefix1}{user_input}{end_of_turn}\n")
|
||||
|
||||
# Adding the Character prefix
|
||||
rows.append(apply_extensions(f"{prefix2.strip() if not is_instruct else prefix2}", "bot_prefix"))
|
||||
rows.append(f"{name1}: {user_input}\n")
|
||||
rows.append(apply_extensions(f"{name2}:", "bot_prefix"))
|
||||
limit = 3
|
||||
else:
|
||||
rows.append(f"{name1}:")
|
||||
limit = 2
|
||||
|
||||
while len(rows) > limit and len(encode(''.join(rows), max_new_tokens)[0]) >= max_length:
|
||||
rows.pop(1)
|
||||
|
||||
prompt = ''.join(rows)
|
||||
|
||||
if also_return_rows:
|
||||
@@ -67,7 +58,6 @@ def generate_chat_prompt(user_input, max_new_tokens, name1, name2, context, chat
|
||||
else:
|
||||
return prompt
|
||||
|
||||
|
||||
def extract_message_from_reply(reply, name1, name2, stop_at_newline):
|
||||
next_character_found = False
|
||||
|
||||
@@ -87,25 +77,17 @@ def extract_message_from_reply(reply, name1, name2, stop_at_newline):
|
||||
# is completed, trim it
|
||||
if not next_character_found:
|
||||
for string in [f"\n{name1}:", f"\n{name2}:"]:
|
||||
for j in range(len(string) - 1, 0, -1):
|
||||
for j in range(len(string)-1, 0, -1):
|
||||
if reply[-j:] == string[:j]:
|
||||
reply = reply[:-j]
|
||||
break
|
||||
else:
|
||||
continue
|
||||
break
|
||||
|
||||
reply = fix_newlines(reply)
|
||||
return reply, next_character_found
|
||||
|
||||
|
||||
def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn, regenerate=False):
|
||||
if mode == 'instruct':
|
||||
stopping_strings = [f"\n{name1}", f"\n{name2}"]
|
||||
else:
|
||||
stopping_strings = [f"\n{name1}:", f"\n{name2}:"]
|
||||
|
||||
eos_token = '\n' if generate_state['stop_at_newline'] else None
|
||||
def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, name1, name2, context, stop_at_newline, chat_prompt_size, chat_generation_attempts=1, regenerate=False):
|
||||
just_started = True
|
||||
eos_token = '\n' if stop_at_newline else None
|
||||
name1_original = name1
|
||||
if 'pygmalion' in shared.model_name.lower():
|
||||
name1 = "You"
|
||||
@@ -114,7 +96,7 @@ def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_tu
|
||||
visible_text = None
|
||||
custom_generate_chat_prompt = None
|
||||
for extension, _ in extensions_module.iterator():
|
||||
if hasattr(extension, 'input_hijack') and extension.input_hijack['state']:
|
||||
if hasattr(extension, 'input_hijack') and extension.input_hijack['state'] == True:
|
||||
extension.input_hijack['state'] = False
|
||||
text, visible_text = extension.input_hijack['value']
|
||||
if custom_generate_chat_prompt is None and hasattr(extension, 'custom_generate_chat_prompt'):
|
||||
@@ -122,30 +104,32 @@ def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_tu
|
||||
|
||||
if visible_text is None:
|
||||
visible_text = text
|
||||
if shared.args.chat:
|
||||
visible_text = visible_text.replace('\n', '<br>')
|
||||
text = apply_extensions(text, "input")
|
||||
|
||||
kwargs = {'end_of_turn': end_of_turn, 'is_instruct': mode == 'instruct'}
|
||||
if custom_generate_chat_prompt is None:
|
||||
prompt = generate_chat_prompt(text, generate_state['max_new_tokens'], name1, name2, context, generate_state['chat_prompt_size'], **kwargs)
|
||||
prompt = generate_chat_prompt(text, max_new_tokens, name1, name2, context, chat_prompt_size)
|
||||
else:
|
||||
prompt = custom_generate_chat_prompt(text, generate_state['max_new_tokens'], name1, name2, context, generate_state['chat_prompt_size'], **kwargs)
|
||||
prompt = custom_generate_chat_prompt(text, max_new_tokens, name1, name2, context, chat_prompt_size)
|
||||
|
||||
# Yield *Is typing...*
|
||||
if not regenerate:
|
||||
yield shared.history['visible'] + [[visible_text, shared.processing_message]]
|
||||
yield shared.history['visible']+[[visible_text, shared.processing_message]]
|
||||
|
||||
# Generate
|
||||
cumulative_reply = ''
|
||||
just_started = True
|
||||
for i in range(generate_state['chat_generation_attempts']):
|
||||
for i in range(chat_generation_attempts):
|
||||
reply = None
|
||||
for reply in generate_reply(f"{prompt}{' ' if len(cumulative_reply) > 0 else ''}{cumulative_reply}", generate_state, eos_token=eos_token, stopping_strings=stopping_strings):
|
||||
for reply in generate_reply(f"{prompt}{' ' if len(cumulative_reply) > 0 else ''}{cumulative_reply}", max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, eos_token=eos_token, stopping_strings=[f"\n{name1}:", f"\n{name2}:"]):
|
||||
reply = cumulative_reply + reply
|
||||
|
||||
# Extracting the reply
|
||||
reply, next_character_found = extract_message_from_reply(reply, name1, name2, generate_state['stop_at_newline'])
|
||||
reply, next_character_found = extract_message_from_reply(reply, name1, name2, stop_at_newline)
|
||||
visible_reply = re.sub("(<USER>|<user>|{{user}})", name1_original, reply)
|
||||
visible_reply = apply_extensions(visible_reply, "output")
|
||||
if shared.args.chat:
|
||||
visible_reply = visible_reply.replace('\n', '<br>')
|
||||
|
||||
# We need this global variable to handle the Stop event,
|
||||
# otherwise gradio gets confused
|
||||
@@ -168,28 +152,23 @@ def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_tu
|
||||
|
||||
yield shared.history['visible']
|
||||
|
||||
def impersonate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, name1, name2, context, stop_at_newline, chat_prompt_size, chat_generation_attempts=1):
|
||||
eos_token = '\n' if stop_at_newline else None
|
||||
|
||||
def impersonate_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn):
|
||||
if mode == 'instruct':
|
||||
stopping_strings = [f"\n{name1}", f"\n{name2}"]
|
||||
else:
|
||||
stopping_strings = [f"\n{name1}:", f"\n{name2}:"]
|
||||
|
||||
eos_token = '\n' if generate_state['stop_at_newline'] else None
|
||||
if 'pygmalion' in shared.model_name.lower():
|
||||
name1 = "You"
|
||||
|
||||
prompt = generate_chat_prompt(text, generate_state['max_new_tokens'], name1, name2, context, generate_state['chat_prompt_size'], impersonate=True, end_of_turn=end_of_turn)
|
||||
prompt = generate_chat_prompt(text, max_new_tokens, name1, name2, context, chat_prompt_size, impersonate=True)
|
||||
|
||||
# Yield *Is typing...*
|
||||
yield shared.processing_message
|
||||
|
||||
cumulative_reply = ''
|
||||
for i in range(generate_state['chat_generation_attempts']):
|
||||
for i in range(chat_generation_attempts):
|
||||
reply = None
|
||||
for reply in generate_reply(f"{prompt}{' ' if len(cumulative_reply) > 0 else ''}{cumulative_reply}", generate_state, eos_token=eos_token, stopping_strings=stopping_strings):
|
||||
for reply in generate_reply(f"{prompt}{' ' if len(cumulative_reply) > 0 else ''}{cumulative_reply}", max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, eos_token=eos_token, stopping_strings=[f"\n{name1}:", f"\n{name2}:"]):
|
||||
reply = cumulative_reply + reply
|
||||
reply, next_character_found = extract_message_from_reply(reply, name1, name2, generate_state['stop_at_newline'])
|
||||
reply, next_character_found = extract_message_from_reply(reply, name1, name2, stop_at_newline)
|
||||
yield reply
|
||||
if next_character_found:
|
||||
break
|
||||
@@ -199,34 +178,36 @@ def impersonate_wrapper(text, generate_state, name1, name2, context, mode, end_o
|
||||
|
||||
yield reply
|
||||
|
||||
def cai_chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, name1, name2, context, stop_at_newline, chat_prompt_size, chat_generation_attempts=1):
|
||||
for history in chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, name1, name2, context, stop_at_newline, chat_prompt_size, chat_generation_attempts):
|
||||
yield generate_chat_html(history, name1, name2, shared.character)
|
||||
|
||||
def cai_chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn):
|
||||
for history in chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn):
|
||||
yield chat_html_wrapper(history, name1, name2, mode)
|
||||
|
||||
|
||||
def regenerate_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn):
|
||||
def regenerate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, name1, name2, context, stop_at_newline, chat_prompt_size, chat_generation_attempts=1):
|
||||
if (shared.character != 'None' and len(shared.history['visible']) == 1) or len(shared.history['internal']) == 0:
|
||||
yield chat_html_wrapper(shared.history['visible'], name1, name2, mode)
|
||||
yield generate_chat_output(shared.history['visible'], name1, name2, shared.character)
|
||||
else:
|
||||
last_visible = shared.history['visible'].pop()
|
||||
last_internal = shared.history['internal'].pop()
|
||||
# Yield '*Is typing...*'
|
||||
yield chat_html_wrapper(shared.history['visible'] + [[last_visible[0], shared.processing_message]], name1, name2, mode)
|
||||
for history in chatbot_wrapper(last_internal[0], generate_state, name1, name2, context, mode, end_of_turn, regenerate=True):
|
||||
yield generate_chat_output(shared.history['visible']+[[last_visible[0], shared.processing_message]], name1, name2, shared.character)
|
||||
for history in chatbot_wrapper(last_internal[0], max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, name1, name2, context, stop_at_newline, chat_prompt_size, chat_generation_attempts, regenerate=True):
|
||||
if shared.args.cai_chat:
|
||||
shared.history['visible'][-1] = [last_visible[0], history[-1][1]]
|
||||
yield chat_html_wrapper(shared.history['visible'], name1, name2, mode)
|
||||
else:
|
||||
shared.history['visible'][-1] = (last_visible[0], history[-1][1])
|
||||
yield generate_chat_output(shared.history['visible'], name1, name2, shared.character)
|
||||
|
||||
|
||||
def remove_last_message(name1, name2, mode):
|
||||
def remove_last_message(name1, name2):
|
||||
if len(shared.history['visible']) > 0 and shared.history['internal'][-1][0] != '<|BEGIN-VISIBLE-CHAT|>':
|
||||
last = shared.history['visible'].pop()
|
||||
shared.history['internal'].pop()
|
||||
else:
|
||||
last = ['', '']
|
||||
|
||||
return chat_html_wrapper(shared.history['visible'], name1, name2, mode), last[0]
|
||||
|
||||
if shared.args.cai_chat:
|
||||
return generate_chat_html(shared.history['visible'], name1, name2, shared.character), last[0]
|
||||
else:
|
||||
return shared.history['visible'], last[0]
|
||||
|
||||
def send_last_reply_to_input():
|
||||
if len(shared.history['internal']) > 0:
|
||||
@@ -234,20 +215,20 @@ def send_last_reply_to_input():
|
||||
else:
|
||||
return ''
|
||||
|
||||
|
||||
def replace_last_reply(text, name1, name2, mode):
|
||||
def replace_last_reply(text, name1, name2):
|
||||
if len(shared.history['visible']) > 0:
|
||||
if shared.args.cai_chat:
|
||||
shared.history['visible'][-1][1] = text
|
||||
else:
|
||||
shared.history['visible'][-1] = (shared.history['visible'][-1][0], text)
|
||||
shared.history['internal'][-1][1] = apply_extensions(text, "input")
|
||||
|
||||
return chat_html_wrapper(shared.history['visible'], name1, name2, mode)
|
||||
|
||||
return generate_chat_output(shared.history['visible'], name1, name2, shared.character)
|
||||
|
||||
def clear_html():
|
||||
return chat_html_wrapper([], "", "")
|
||||
return generate_chat_html([], "", "", shared.character)
|
||||
|
||||
|
||||
def clear_chat_log(name1, name2, greeting, mode):
|
||||
def clear_chat_log(name1, name2, greeting):
|
||||
shared.history['visible'] = []
|
||||
shared.history['internal'] = []
|
||||
|
||||
@@ -255,14 +236,12 @@ def clear_chat_log(name1, name2, greeting, mode):
|
||||
shared.history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', greeting]]
|
||||
shared.history['visible'] += [['', apply_extensions(greeting, "output")]]
|
||||
|
||||
return chat_html_wrapper(shared.history['visible'], name1, name2, mode)
|
||||
return generate_chat_output(shared.history['visible'], name1, name2, shared.character)
|
||||
|
||||
def redraw_html(name1, name2):
|
||||
return generate_chat_html(shared.history['visible'], name1, name2, shared.character)
|
||||
|
||||
def redraw_html(name1, name2, mode):
|
||||
return chat_html_wrapper(shared.history['visible'], name1, name2, mode)
|
||||
|
||||
|
||||
def tokenize_dialogue(dialogue, name1, name2, mode):
|
||||
def tokenize_dialogue(dialogue, name1, name2):
|
||||
history = []
|
||||
|
||||
dialogue = re.sub('<START>', '', dialogue)
|
||||
@@ -274,8 +253,8 @@ def tokenize_dialogue(dialogue, name1, name2, mode):
|
||||
return history
|
||||
|
||||
messages = []
|
||||
for i in range(len(idx) - 1):
|
||||
messages.append(dialogue[idx[i]:idx[i + 1]].strip())
|
||||
for i in range(len(idx)-1):
|
||||
messages.append(dialogue[idx[i]:idx[i+1]].strip())
|
||||
messages.append(dialogue[idx[-1]:].strip())
|
||||
|
||||
entry = ['', '']
|
||||
@@ -293,13 +272,12 @@ def tokenize_dialogue(dialogue, name1, name2, mode):
|
||||
for column in row:
|
||||
print("\n")
|
||||
for line in column.strip().split('\n'):
|
||||
print("| " + line + "\n")
|
||||
print("| "+line+"\n")
|
||||
print("|\n")
|
||||
print("------------------------------")
|
||||
|
||||
return history
|
||||
|
||||
|
||||
def save_history(timestamp=True):
|
||||
if timestamp:
|
||||
fname = f"{shared.character}_{datetime.now().strftime('%Y%m%d-%H%M%S')}.json"
|
||||
@@ -311,7 +289,6 @@ def save_history(timestamp=True):
|
||||
f.write(json.dumps({'data': shared.history['internal'], 'data_visible': shared.history['visible']}, indent=2))
|
||||
return Path(f'logs/{fname}')
|
||||
|
||||
|
||||
def load_history(file, name1, name2):
|
||||
file = file.decode('utf-8')
|
||||
try:
|
||||
@@ -326,22 +303,20 @@ def load_history(file, name1, name2):
|
||||
elif 'chat' in j:
|
||||
shared.history['internal'] = [':'.join(x.split(':')[1:]).strip() for x in j['chat']]
|
||||
if len(j['chat']) > 0 and j['chat'][0].startswith(f'{name2}:'):
|
||||
shared.history['internal'] = [['<|BEGIN-VISIBLE-CHAT|>', shared.history['internal'][0]]] + [[shared.history['internal'][i], shared.history['internal'][i + 1]] for i in range(1, len(shared.history['internal']) - 1, 2)]
|
||||
shared.history['internal'] = [['<|BEGIN-VISIBLE-CHAT|>', shared.history['internal'][0]]] + [[shared.history['internal'][i], shared.history['internal'][i+1]] for i in range(1, len(shared.history['internal'])-1, 2)]
|
||||
shared.history['visible'] = copy.deepcopy(shared.history['internal'])
|
||||
shared.history['visible'][0][0] = ''
|
||||
else:
|
||||
shared.history['internal'] = [[shared.history['internal'][i], shared.history['internal'][i + 1]] for i in range(0, len(shared.history['internal']) - 1, 2)]
|
||||
shared.history['internal'] = [[shared.history['internal'][i], shared.history['internal'][i+1]] for i in range(0, len(shared.history['internal'])-1, 2)]
|
||||
shared.history['visible'] = copy.deepcopy(shared.history['internal'])
|
||||
except:
|
||||
shared.history['internal'] = tokenize_dialogue(file, name1, name2)
|
||||
shared.history['visible'] = copy.deepcopy(shared.history['internal'])
|
||||
|
||||
|
||||
def replace_character_names(text, name1, name2):
|
||||
text = text.replace('{{user}}', name1).replace('{{char}}', name2)
|
||||
return text.replace('<USER>', name1).replace('<BOT>', name2)
|
||||
|
||||
|
||||
def build_pygmalion_style_context(data):
|
||||
context = ""
|
||||
if 'char_persona' in data and data['char_persona'] != '':
|
||||
@@ -351,37 +326,15 @@ def build_pygmalion_style_context(data):
|
||||
context = f"{context.strip()}\n<START>\n"
|
||||
return context
|
||||
|
||||
|
||||
def generate_pfp_cache(character):
|
||||
cache_folder = Path("cache")
|
||||
if not cache_folder.exists():
|
||||
cache_folder.mkdir()
|
||||
|
||||
for path in [Path(f"characters/{character}.{extension}") for extension in ['png', 'jpg', 'jpeg']]:
|
||||
if path.exists():
|
||||
img = make_thumbnail(Image.open(path))
|
||||
img.save(Path('cache/pfp_character.png'), format='PNG')
|
||||
return img
|
||||
return None
|
||||
|
||||
|
||||
def load_character(character, name1, name2, mode):
|
||||
def load_character(character, name1, name2):
|
||||
shared.character = character
|
||||
shared.history['internal'] = []
|
||||
shared.history['visible'] = []
|
||||
context = greeting = end_of_turn = ""
|
||||
greeting_field = 'greeting'
|
||||
picture = None
|
||||
|
||||
# Deleting the profile picture cache, if any
|
||||
if Path("cache/pfp_character.png").exists():
|
||||
Path("cache/pfp_character.png").unlink()
|
||||
greeting = ""
|
||||
|
||||
if character != 'None':
|
||||
folder = 'characters' if not mode == 'instruct' else 'characters/instruction-following'
|
||||
picture = generate_pfp_cache(character)
|
||||
for extension in ["yml", "yaml", "json"]:
|
||||
filepath = Path(f'{folder}/{character}.{extension}')
|
||||
filepath = Path(f'characters/{character}.{extension}')
|
||||
if filepath.exists():
|
||||
break
|
||||
file_contents = open(filepath, 'r', encoding='utf-8').read()
|
||||
@@ -397,21 +350,19 @@ def load_character(character, name1, name2, mode):
|
||||
|
||||
if 'context' in data:
|
||||
context = f"{data['context'].strip()}\n\n"
|
||||
elif "char_persona" in data:
|
||||
greeting_field = 'greeting'
|
||||
else:
|
||||
context = build_pygmalion_style_context(data)
|
||||
greeting_field = 'char_greeting'
|
||||
|
||||
if 'example_dialogue' in data:
|
||||
if 'example_dialogue' in data and data['example_dialogue'] != '':
|
||||
context += f"{data['example_dialogue'].strip()}\n"
|
||||
if greeting_field in data:
|
||||
if greeting_field in data and len(data[greeting_field].strip()) > 0:
|
||||
greeting = data[greeting_field]
|
||||
if 'end_of_turn' in data:
|
||||
end_of_turn = data['end_of_turn']
|
||||
else:
|
||||
context = shared.settings['context']
|
||||
name2 = shared.settings['name2']
|
||||
greeting = shared.settings['greeting']
|
||||
end_of_turn = shared.settings['end_of_turn']
|
||||
|
||||
if Path(f'logs/{shared.character}_persistent.json').exists():
|
||||
load_history(open(Path(f'logs/{shared.character}_persistent.json'), 'rb').read(), name1, name2)
|
||||
@@ -419,12 +370,13 @@ def load_character(character, name1, name2, mode):
|
||||
shared.history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', greeting]]
|
||||
shared.history['visible'] += [['', apply_extensions(greeting, "output")]]
|
||||
|
||||
return name1, name2, picture, greeting, context, end_of_turn, chat_html_wrapper(shared.history['visible'], name1, name2, mode, reset_cache=True)
|
||||
|
||||
if shared.args.cai_chat:
|
||||
return name1, name2, greeting, context, generate_chat_html(shared.history['visible'], name1, name2, shared.character)
|
||||
else:
|
||||
return name1, name2, greeting, context, shared.history['visible']
|
||||
|
||||
def load_default_history(name1, name2):
|
||||
load_character("None", name1, name2, "chat")
|
||||
|
||||
load_character("None", name1, name2)
|
||||
|
||||
def upload_character(json_file, img, tavern=False):
|
||||
json_file = json_file if type(json_file) == str else json_file.decode('utf-8')
|
||||
@@ -444,7 +396,6 @@ def upload_character(json_file, img, tavern=False):
|
||||
print(f'New character saved to "characters/{outfile_name}.json".')
|
||||
return outfile_name
|
||||
|
||||
|
||||
def upload_tavern_character(img, name1, name2):
|
||||
_img = Image.open(io.BytesIO(img))
|
||||
_img.getexif()
|
||||
@@ -453,18 +404,7 @@ def upload_tavern_character(img, name1, name2):
|
||||
_json = {"char_name": _json['name'], "char_persona": _json['description'], "char_greeting": _json["first_mes"], "example_dialogue": _json['mes_example'], "world_scenario": _json['scenario']}
|
||||
return upload_character(json.dumps(_json), img, tavern=True)
|
||||
|
||||
|
||||
def upload_your_profile_picture(img, name1, name2, mode):
|
||||
cache_folder = Path("cache")
|
||||
if not cache_folder.exists():
|
||||
cache_folder.mkdir()
|
||||
|
||||
if img is None:
|
||||
if Path("cache/pfp_me.png").exists():
|
||||
Path("cache/pfp_me.png").unlink()
|
||||
else:
|
||||
img = make_thumbnail(img)
|
||||
img.save(Path('cache/pfp_me.png'))
|
||||
print('Profile picture saved to "cache/pfp_me.png"')
|
||||
|
||||
return chat_html_wrapper(shared.history['visible'], name1, name2, mode, reset_cache=True)
|
||||
def upload_your_profile_picture(img):
|
||||
img = Image.open(io.BytesIO(img))
|
||||
img.save(Path('img_me.png'))
|
||||
print('Profile picture saved to "img_me.png"')
|
||||
|
||||
@@ -9,7 +9,6 @@ state = {}
|
||||
available_extensions = []
|
||||
setup_called = set()
|
||||
|
||||
|
||||
def load_extensions():
|
||||
global state
|
||||
for i, name in enumerate(shared.args.extensions):
|
||||
@@ -24,16 +23,12 @@ def load_extensions():
|
||||
traceback.print_exc()
|
||||
|
||||
# This iterator returns the extensions in the order specified in the command-line
|
||||
|
||||
|
||||
def iterator():
|
||||
for name in sorted(state, key=lambda x: state[x][1]):
|
||||
for name in sorted(state, key=lambda x : state[x][1]):
|
||||
if state[name][0] == True:
|
||||
yield eval(f"extensions.{name}.script"), name
|
||||
|
||||
# Extension functions that map string -> string
|
||||
|
||||
|
||||
def apply_extensions(text, typ):
|
||||
for extension, _ in iterator():
|
||||
if typ == "input" and hasattr(extension, "input_modifier"):
|
||||
@@ -44,7 +39,6 @@ def apply_extensions(text, typ):
|
||||
text = extension.bot_prefix_modifier(text)
|
||||
return text
|
||||
|
||||
|
||||
def create_extensions_block():
|
||||
global setup_called
|
||||
|
||||
|
||||
@@ -6,11 +6,10 @@ This is a library for formatting text outputs as nice HTML.
|
||||
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import markdown
|
||||
from PIL import Image, ImageOps
|
||||
from PIL import Image
|
||||
|
||||
# This is to store the paths to the thumbnails of the profile pictures
|
||||
image_cache = {}
|
||||
@@ -21,9 +20,6 @@ with open(Path(__file__).resolve().parent / '../css/html_4chan_style.css', 'r')
|
||||
_4chan_css = css_f.read()
|
||||
with open(Path(__file__).resolve().parent / '../css/html_cai_style.css', 'r') as f:
|
||||
cai_css = f.read()
|
||||
with open(Path(__file__).resolve().parent / '../css/html_instruct_style.css', 'r') as f:
|
||||
instruct_css = f.read()
|
||||
|
||||
|
||||
def fix_newlines(string):
|
||||
string = string.replace('\n', '\n\n')
|
||||
@@ -32,8 +28,6 @@ def fix_newlines(string):
|
||||
return string
|
||||
|
||||
# This could probably be generalized and improved
|
||||
|
||||
|
||||
def convert_to_markdown(string):
|
||||
string = string.replace('\\begin{code}', '```')
|
||||
string = string.replace('\\end{code}', '```')
|
||||
@@ -43,13 +37,11 @@ def convert_to_markdown(string):
|
||||
string = fix_newlines(string)
|
||||
return markdown.markdown(string, extensions=['fenced_code'])
|
||||
|
||||
|
||||
def generate_basic_html(string):
|
||||
string = convert_to_markdown(string)
|
||||
string = f'<style>{readable_css}</style><div class="container">{string}</div>'
|
||||
return string
|
||||
|
||||
|
||||
def process_post(post, c):
|
||||
t = post.split('\n')
|
||||
number = t[0].split(' ')[1]
|
||||
@@ -64,7 +56,6 @@ def process_post(post, c):
|
||||
src = f'<span class="name">Anonymous </span> <span class="number">No.{number}</span>\n{src}'
|
||||
return src
|
||||
|
||||
|
||||
def generate_4chan_html(f):
|
||||
posts = []
|
||||
post = ''
|
||||
@@ -104,15 +95,6 @@ def generate_4chan_html(f):
|
||||
|
||||
return output
|
||||
|
||||
|
||||
def make_thumbnail(image):
|
||||
image = image.resize((350, round(image.size[1] / image.size[0] * 350)), Image.Resampling.LANCZOS)
|
||||
if image.size[1] > 470:
|
||||
image = ImageOps.fit(image, (350, 470), Image.ANTIALIAS)
|
||||
|
||||
return image
|
||||
|
||||
|
||||
def get_image_cache(path):
|
||||
cache_folder = Path("cache")
|
||||
if not cache_folder.exists():
|
||||
@@ -120,56 +102,28 @@ def get_image_cache(path):
|
||||
|
||||
mtime = os.stat(path).st_mtime
|
||||
if (path in image_cache and mtime != image_cache[path][0]) or (path not in image_cache):
|
||||
img = make_thumbnail(Image.open(path))
|
||||
img = Image.open(path)
|
||||
img.thumbnail((200, 200))
|
||||
output_file = Path(f'cache/{path.name}_cache.png')
|
||||
img.convert('RGB').save(output_file, format='PNG')
|
||||
image_cache[path] = [mtime, output_file.as_posix()]
|
||||
|
||||
return image_cache[path][1]
|
||||
|
||||
def load_html_image(paths):
|
||||
for str_path in paths:
|
||||
path = Path(str_path)
|
||||
if path.exists():
|
||||
return f'<img src="file/{get_image_cache(path)}">'
|
||||
return ''
|
||||
|
||||
def generate_instruct_html(history):
|
||||
output = f'<style>{instruct_css}</style><div class="chat" id="chat">'
|
||||
for i, _row in enumerate(history[::-1]):
|
||||
row = [convert_to_markdown(entry) for entry in _row]
|
||||
|
||||
output += f"""
|
||||
<div class="assistant-message">
|
||||
<div class="text">
|
||||
<div class="message-body">
|
||||
{row[1]}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
"""
|
||||
|
||||
if len(row[0]) == 0: # don't display empty user messages
|
||||
continue
|
||||
|
||||
output += f"""
|
||||
<div class="user-message">
|
||||
<div class="text">
|
||||
<div class="message-body">
|
||||
{row[0]}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
"""
|
||||
|
||||
output += "</div>"
|
||||
|
||||
return output
|
||||
|
||||
|
||||
def generate_cai_chat_html(history, name1, name2, reset_cache=False):
|
||||
def generate_chat_html(history, name1, name2, character):
|
||||
output = f'<style>{cai_css}</style><div class="chat" id="chat">'
|
||||
|
||||
# The time.time() is to prevent the brower from caching the image
|
||||
suffix = f"?{time.time()}" if reset_cache else f"?{name2}"
|
||||
img_bot = f'<img src="file/cache/pfp_character.png{suffix}">' if Path("cache/pfp_character.png").exists() else ''
|
||||
img_me = f'<img src="file/cache/pfp_me.png{suffix}">' if Path("cache/pfp_me.png").exists() else ''
|
||||
img_bot = load_html_image([f"characters/{character}.{ext}" for ext in ['png', 'jpg', 'jpeg']] + ["img_bot.png","img_bot.jpg","img_bot.jpeg"])
|
||||
img_me = load_html_image(["img_me.png", "img_me.jpg", "img_me.jpeg"])
|
||||
|
||||
for i, _row in enumerate(history[::-1]):
|
||||
for i,_row in enumerate(history[::-1]):
|
||||
row = [convert_to_markdown(entry) for entry in _row]
|
||||
|
||||
output += f"""
|
||||
@@ -209,18 +163,3 @@ def generate_cai_chat_html(history, name1, name2, reset_cache=False):
|
||||
|
||||
output += "</div>"
|
||||
return output
|
||||
|
||||
|
||||
def generate_chat_html(history, name1, name2):
|
||||
return generate_cai_chat_html(history, name1, name2)
|
||||
|
||||
|
||||
def chat_html_wrapper(history, name1, name2, mode, reset_cache=False):
|
||||
if mode == "cai-chat":
|
||||
return generate_cai_chat_html(history, name1, name2, reset_cache)
|
||||
elif mode == "chat":
|
||||
return generate_chat_html(history, name1, name2)
|
||||
elif mode == "instruct":
|
||||
return generate_instruct_html(history)
|
||||
else:
|
||||
return ''
|
||||
|
||||
@@ -50,9 +50,9 @@ class LlamaCppModel:
|
||||
params.top_k = top_k
|
||||
params.temp = temperature
|
||||
params.repeat_penalty = repetition_penalty
|
||||
# params.repeat_last_n = repeat_last_n
|
||||
#params.repeat_last_n = repeat_last_n
|
||||
|
||||
# self.model.params = params
|
||||
#self.model.params = params
|
||||
self.model.add_bos()
|
||||
self.model.update_input(context)
|
||||
|
||||
|
||||
@@ -1,63 +0,0 @@
|
||||
'''
|
||||
Based on
|
||||
https://github.com/abetlen/llama-cpp-python
|
||||
|
||||
Documentation:
|
||||
https://abetlen.github.io/llama-cpp-python/
|
||||
'''
|
||||
|
||||
from llama_cpp import Llama
|
||||
|
||||
from modules import shared
|
||||
from modules.callbacks import Iteratorize
|
||||
|
||||
|
||||
class LlamaCppModel:
|
||||
def __init__(self):
|
||||
self.initialized = False
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(self, path):
|
||||
result = self()
|
||||
|
||||
params = {
|
||||
'model_path': str(path),
|
||||
'n_ctx': 2048,
|
||||
'seed': 0,
|
||||
'n_threads': shared.args.threads or None
|
||||
}
|
||||
self.model = Llama(**params)
|
||||
|
||||
# This is ugly, but the model and the tokenizer are the same object in this library.
|
||||
return result, result
|
||||
|
||||
def encode(self, string):
|
||||
if type(string) is str:
|
||||
string = string.encode()
|
||||
return self.model.tokenize(string)
|
||||
|
||||
def generate(self, context="", token_count=20, temperature=1, top_p=1, top_k=50, repetition_penalty=1, callback=None):
|
||||
if type(context) is str:
|
||||
context = context.encode()
|
||||
tokens = self.model.tokenize(context)
|
||||
|
||||
output = b""
|
||||
count = 0
|
||||
for token in self.model.generate(tokens, top_k=top_k, top_p=top_p, temp=temperature, repeat_penalty=repetition_penalty):
|
||||
text = self.model.detokenize([token])
|
||||
output += text
|
||||
if callback:
|
||||
callback(text.decode())
|
||||
|
||||
count += 1
|
||||
if count >= token_count or (token == self.model.token_eos()):
|
||||
break
|
||||
|
||||
return output.decode()
|
||||
|
||||
def generate_with_streaming(self, **kwargs):
|
||||
with Iteratorize(self.generate, kwargs, callback=None) as generator:
|
||||
reply = ''
|
||||
for token in generator:
|
||||
reply += token
|
||||
yield reply
|
||||
@@ -10,7 +10,7 @@ import torch
|
||||
import transformers
|
||||
from accelerate import infer_auto_device_map, init_empty_weights
|
||||
from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer,
|
||||
BitsAndBytesConfig, LlamaTokenizer)
|
||||
BitsAndBytesConfig)
|
||||
|
||||
import modules.shared as shared
|
||||
|
||||
@@ -42,7 +42,7 @@ def load_model(model_name):
|
||||
t0 = time.time()
|
||||
|
||||
shared.is_RWKV = 'rwkv-' in model_name.lower()
|
||||
shared.is_llamacpp = len(list(Path(f'{shared.args.model_dir}/{model_name}').glob('ggml*.bin'))) > 0
|
||||
shared.is_llamacpp = len(list(Path(f'models/{model_name}').glob('ggml*.bin'))) > 0
|
||||
|
||||
# Default settings
|
||||
if not any([shared.args.cpu, shared.args.load_in_8bit, shared.args.wbits, shared.args.auto_devices, shared.args.disk, shared.args.gpu_memory is not None, shared.args.cpu_memory is not None, shared.args.deepspeed, shared.args.flexgen, shared.is_RWKV, shared.is_llamacpp]):
|
||||
@@ -103,9 +103,9 @@ def load_model(model_name):
|
||||
|
||||
# llamacpp model
|
||||
elif shared.is_llamacpp:
|
||||
from modules.llamacpp_model_alternative import LlamaCppModel
|
||||
from modules.llamacpp_model import LlamaCppModel
|
||||
|
||||
model_file = list(Path(f'{shared.args.model_dir}/{model_name}').glob('ggml*.bin'))[0]
|
||||
model_file = list(Path(f'models/{model_name}').glob('ggml*.bin'))[0]
|
||||
print(f"llama.cpp weights detected: {model_file}\n")
|
||||
|
||||
model, tokenizer = LlamaCppModel.from_pretrained(model_file)
|
||||
@@ -132,7 +132,7 @@ def load_model(model_name):
|
||||
params["torch_dtype"] = torch.float16
|
||||
|
||||
if shared.args.gpu_memory:
|
||||
memory_map = list(map(lambda x: x.strip(), shared.args.gpu_memory))
|
||||
memory_map = list(map(lambda x : x.strip(), shared.args.gpu_memory))
|
||||
max_cpu_memory = shared.args.cpu_memory.strip() if shared.args.cpu_memory is not None else '99GiB'
|
||||
max_memory = {}
|
||||
for i in range(len(memory_map)):
|
||||
@@ -140,11 +140,11 @@ def load_model(model_name):
|
||||
max_memory['cpu'] = max_cpu_memory
|
||||
params['max_memory'] = max_memory
|
||||
elif shared.args.auto_devices:
|
||||
total_mem = (torch.cuda.get_device_properties(0).total_memory / (1024 * 1024))
|
||||
suggestion = round((total_mem - 1000) / 1000) * 1000
|
||||
total_mem = (torch.cuda.get_device_properties(0).total_memory / (1024*1024))
|
||||
suggestion = round((total_mem-1000) / 1000) * 1000
|
||||
if total_mem - suggestion < 800:
|
||||
suggestion -= 1000
|
||||
suggestion = int(round(suggestion / 1000))
|
||||
suggestion = int(round(suggestion/1000))
|
||||
print(f"\033[1;32;1mAuto-assiging --gpu-memory {suggestion} for your GPU to try to prevent out-of-memory errors.\nYou can manually set other values.\033[0;37;0m")
|
||||
|
||||
max_memory = {0: f'{suggestion}GiB', 'cpu': f'{shared.args.cpu_memory or 99}GiB'}
|
||||
@@ -164,7 +164,7 @@ def load_model(model_name):
|
||||
model,
|
||||
dtype=torch.int8,
|
||||
max_memory=params['max_memory'],
|
||||
no_split_module_classes=model._no_split_modules
|
||||
no_split_module_classes = model._no_split_modules
|
||||
)
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(checkpoint, **params)
|
||||
@@ -172,8 +172,6 @@ def load_model(model_name):
|
||||
# Loading the tokenizer
|
||||
if any((k in shared.model_name.lower() for k in ['gpt4chan', 'gpt-4chan'])) and Path(f"{shared.args.model_dir}/gpt-j-6B/").exists():
|
||||
tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/gpt-j-6B/"))
|
||||
elif type(model) is transformers.LlamaForCausalLM:
|
||||
tokenizer = LlamaTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/{shared.model_name}/"), clean_up_tokenization_spaces=True)
|
||||
else:
|
||||
tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/{shared.model_name}/"))
|
||||
tokenizer.truncation_side = 'left'
|
||||
@@ -181,7 +179,6 @@ def load_model(model_name):
|
||||
print(f"Loaded the model in {(time.time()-t0):.2f} seconds.")
|
||||
return model, tokenizer
|
||||
|
||||
|
||||
def load_soft_prompt(name):
|
||||
if name == 'None':
|
||||
shared.soft_prompt = False
|
||||
|
||||
@@ -33,7 +33,6 @@ settings = {
|
||||
'name2': 'Assistant',
|
||||
'context': 'This is a conversation with your Assistant. The Assistant is very helpful and is eager to chat with you and answer your questions.',
|
||||
'greeting': 'Hello there!',
|
||||
'end_of_turn': '',
|
||||
'stop_at_newline': False,
|
||||
'chat_prompt_size': 2048,
|
||||
'chat_prompt_size_min': 0,
|
||||
@@ -45,7 +44,6 @@ settings = {
|
||||
'chat_default_extensions': ["gallery"],
|
||||
'presets': {
|
||||
'default': 'NovelAI-Sphinx Moth',
|
||||
'.*(alpaca|llama)': "LLaMA-Precise",
|
||||
'.*pygmalion': 'NovelAI-Storywriter',
|
||||
'.*RWKV': 'Naive',
|
||||
},
|
||||
@@ -61,7 +59,6 @@ settings = {
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def str2bool(v):
|
||||
if isinstance(v, bool):
|
||||
return v
|
||||
@@ -72,13 +69,12 @@ def str2bool(v):
|
||||
else:
|
||||
raise argparse.ArgumentTypeError('Boolean value expected.')
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser(formatter_class=lambda prog: argparse.HelpFormatter(prog, max_help_position=54))
|
||||
parser = argparse.ArgumentParser(formatter_class=lambda prog: argparse.HelpFormatter(prog,max_help_position=54))
|
||||
|
||||
# Basic settings
|
||||
parser.add_argument('--notebook', action='store_true', help='Launch the web UI in notebook mode, where the output is written to the same text box as the input.')
|
||||
parser.add_argument('--chat', action='store_true', help='Launch the web UI in chat mode with a style similar to the Character.AI website.')
|
||||
parser.add_argument('--cai-chat', action='store_true', help='DEPRECATED: use --chat instead.')
|
||||
parser.add_argument('--chat', action='store_true', help='Launch the web UI in chat mode.')
|
||||
parser.add_argument('--cai-chat', action='store_true', help='Launch the web UI in chat mode with a style similar to Character.AI\'s. If the file img_bot.png or img_bot.jpg exists in the same folder as server.py, this image will be used as the bot\'s profile picture. Similarly, img_me.png or img_me.jpg will be used as your profile picture.')
|
||||
parser.add_argument('--model', type=str, help='Name of the model to load by default.')
|
||||
parser.add_argument('--lora', type=str, help='Name of the LoRA to apply to the model by default.')
|
||||
parser.add_argument("--model-dir", type=str, default='models/', help="Path to directory with all the models")
|
||||
@@ -135,18 +131,12 @@ parser.add_argument("--gradio-auth-path", type=str, help='Set the gradio authent
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Deprecation warnings for parameters that have been renamed
|
||||
# Provisional, this will be deleted later
|
||||
deprecated_dict = {'gptq_bits': ['wbits', 0], 'gptq_model_type': ['model_type', None], 'gptq_pre_layer': ['prelayer', 0]}
|
||||
for k in deprecated_dict:
|
||||
if eval(f"args.{k}") != deprecated_dict[k][1]:
|
||||
print(f"Warning: --{k} is deprecated and will be removed. Use --{deprecated_dict[k][0]} instead.")
|
||||
exec(f"args.{deprecated_dict[k][0]} = args.{k}")
|
||||
|
||||
# Deprecation warnings for parameters that have been removed
|
||||
if args.cai_chat:
|
||||
print("Warning: --cai-chat is deprecated. Use --chat instead.")
|
||||
args.chat = True
|
||||
|
||||
|
||||
def is_chat():
|
||||
return args.chat
|
||||
return any((args.chat, args.cai_chat))
|
||||
|
||||
@@ -16,12 +16,11 @@ from modules.models import local_rank
|
||||
|
||||
|
||||
def get_max_prompt_length(tokens):
|
||||
max_length = 2048 - tokens
|
||||
max_length = 2048-tokens
|
||||
if shared.soft_prompt:
|
||||
max_length -= shared.soft_prompt_tensor.shape[1]
|
||||
return max_length
|
||||
|
||||
|
||||
def encode(prompt, tokens_to_generate=0, add_special_tokens=True):
|
||||
if any((shared.is_RWKV, shared.is_llamacpp)):
|
||||
input_ids = shared.tokenizer.encode(str(prompt))
|
||||
@@ -29,10 +28,6 @@ def encode(prompt, tokens_to_generate=0, add_special_tokens=True):
|
||||
return input_ids
|
||||
else:
|
||||
input_ids = shared.tokenizer.encode(str(prompt), return_tensors='pt', truncation=True, max_length=get_max_prompt_length(tokens_to_generate), add_special_tokens=add_special_tokens)
|
||||
|
||||
if type(shared.tokenizer) is transformers.LlamaTokenizer and input_ids[0][0] == 29871:
|
||||
input_ids = input_ids[:, 1:]
|
||||
|
||||
if shared.args.cpu:
|
||||
return input_ids
|
||||
elif shared.args.flexgen:
|
||||
@@ -45,7 +40,6 @@ def encode(prompt, tokens_to_generate=0, add_special_tokens=True):
|
||||
else:
|
||||
return input_ids.cuda()
|
||||
|
||||
|
||||
def decode(output_ids):
|
||||
# Open Assistant relies on special tokens like <|endoftext|>
|
||||
if re.match('.*(oasst|galactica)-*', shared.model_name.lower()):
|
||||
@@ -55,17 +49,14 @@ def decode(output_ids):
|
||||
reply = reply.replace(r'<|endoftext|>', '')
|
||||
return reply
|
||||
|
||||
|
||||
def generate_softprompt_input_tensors(input_ids):
|
||||
inputs_embeds = shared.model.transformer.wte(input_ids)
|
||||
inputs_embeds = torch.cat((shared.soft_prompt_tensor, inputs_embeds), dim=1)
|
||||
filler_input_ids = torch.zeros((1, inputs_embeds.shape[1]), dtype=input_ids.dtype).to(shared.model.device)
|
||||
# filler_input_ids += shared.model.config.bos_token_id # setting dummy input_ids to bos tokens
|
||||
#filler_input_ids += shared.model.config.bos_token_id # setting dummy input_ids to bos tokens
|
||||
return inputs_embeds, filler_input_ids
|
||||
|
||||
# Removes empty replies from gpt4chan outputs
|
||||
|
||||
|
||||
def fix_gpt4chan(s):
|
||||
for i in range(10):
|
||||
s = re.sub("--- [0-9]*\n>>[0-9]*\n---", "---", s)
|
||||
@@ -74,8 +65,6 @@ def fix_gpt4chan(s):
|
||||
return s
|
||||
|
||||
# Fix the LaTeX equations in galactica
|
||||
|
||||
|
||||
def fix_galactica(s):
|
||||
s = s.replace(r'\[', r'$')
|
||||
s = s.replace(r'\]', r'$')
|
||||
@@ -86,7 +75,6 @@ def fix_galactica(s):
|
||||
s = re.sub(r"\n{3,}", "\n\n", s)
|
||||
return s
|
||||
|
||||
|
||||
def formatted_outputs(reply, model_name):
|
||||
if not shared.is_chat():
|
||||
if 'galactica' in model_name.lower():
|
||||
@@ -100,29 +88,24 @@ def formatted_outputs(reply, model_name):
|
||||
else:
|
||||
return reply
|
||||
|
||||
|
||||
def clear_torch_cache():
|
||||
gc.collect()
|
||||
if not shared.args.cpu:
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
def set_manual_seed(seed):
|
||||
if seed != -1:
|
||||
torch.manual_seed(seed)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
|
||||
|
||||
def stop_everything_event():
|
||||
shared.stop_everything = True
|
||||
|
||||
|
||||
def generate_reply(question, generate_state, eos_token=None, stopping_strings=[]):
|
||||
def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, eos_token=None, stopping_strings=[]):
|
||||
clear_torch_cache()
|
||||
set_manual_seed(generate_state['seed'])
|
||||
set_manual_seed(seed)
|
||||
shared.stop_everything = False
|
||||
generate_params = {}
|
||||
t0 = time.time()
|
||||
|
||||
original_question = question
|
||||
@@ -134,13 +117,10 @@ def generate_reply(question, generate_state, eos_token=None, stopping_strings=[]
|
||||
# These models are not part of Hugging Face, so we handle them
|
||||
# separately and terminate the function call earlier
|
||||
if any((shared.is_RWKV, shared.is_llamacpp)):
|
||||
for k in ['temperature', 'top_p', 'top_k', 'repetition_penalty']:
|
||||
generate_params[k] = generate_state[k]
|
||||
generate_params["token_count"] = generate_state["max_new_tokens"]
|
||||
try:
|
||||
if shared.args.no_stream:
|
||||
reply = shared.model.generate(context=question, **generate_params)
|
||||
output = original_question + reply
|
||||
reply = shared.model.generate(context=question, token_count=max_new_tokens, temperature=temperature, top_p=top_p, top_k=top_k, repetition_penalty=repetition_penalty)
|
||||
output = original_question+reply
|
||||
if not shared.is_chat():
|
||||
reply = original_question + apply_extensions(reply, "output")
|
||||
yield formatted_outputs(reply, shared.model_name)
|
||||
@@ -150,8 +130,8 @@ def generate_reply(question, generate_state, eos_token=None, stopping_strings=[]
|
||||
|
||||
# RWKV has proper streaming, which is very nice.
|
||||
# No need to generate 8 tokens at a time.
|
||||
for reply in shared.model.generate_with_streaming(context=question, **generate_params):
|
||||
output = original_question + reply
|
||||
for reply in shared.model.generate_with_streaming(context=question, token_count=max_new_tokens, temperature=temperature, top_p=top_p, top_k=top_k, repetition_penalty=repetition_penalty):
|
||||
output = original_question+reply
|
||||
if not shared.is_chat():
|
||||
reply = original_question + apply_extensions(reply, "output")
|
||||
yield formatted_outputs(reply, shared.model_name)
|
||||
@@ -165,7 +145,7 @@ def generate_reply(question, generate_state, eos_token=None, stopping_strings=[]
|
||||
print(f"Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens})")
|
||||
return
|
||||
|
||||
input_ids = encode(question, generate_state['max_new_tokens'])
|
||||
input_ids = encode(question, max_new_tokens)
|
||||
original_input_ids = input_ids
|
||||
output = input_ids[0]
|
||||
|
||||
@@ -178,21 +158,33 @@ def generate_reply(question, generate_state, eos_token=None, stopping_strings=[]
|
||||
t = [encode(string, 0, add_special_tokens=False) for string in stopping_strings]
|
||||
stopping_criteria_list.append(_SentinelTokenStoppingCriteria(sentinel_token_ids=t, starting_idx=len(input_ids[0])))
|
||||
|
||||
generate_params["max_new_tokens"] = generate_state['max_new_tokens']
|
||||
generate_params = {}
|
||||
if not shared.args.flexgen:
|
||||
for k in ["do_sample", "temperature", "top_p", "typical_p", "repetition_penalty", "encoder_repetition_penalty", "top_k", "min_length", "no_repeat_ngram_size", "num_beams", "penalty_alpha", "length_penalty", "early_stopping"]:
|
||||
generate_params[k] = generate_state[k]
|
||||
generate_params["eos_token_id"] = eos_token_ids
|
||||
generate_params["stopping_criteria"] = stopping_criteria_list
|
||||
if shared.args.no_stream:
|
||||
generate_params["min_length"] = 0
|
||||
generate_params.update({
|
||||
"max_new_tokens": max_new_tokens,
|
||||
"eos_token_id": eos_token_ids,
|
||||
"stopping_criteria": stopping_criteria_list,
|
||||
"do_sample": do_sample,
|
||||
"temperature": temperature,
|
||||
"top_p": top_p,
|
||||
"typical_p": typical_p,
|
||||
"repetition_penalty": repetition_penalty,
|
||||
"encoder_repetition_penalty": encoder_repetition_penalty,
|
||||
"top_k": top_k,
|
||||
"min_length": min_length if shared.args.no_stream else 0,
|
||||
"no_repeat_ngram_size": no_repeat_ngram_size,
|
||||
"num_beams": num_beams,
|
||||
"penalty_alpha": penalty_alpha,
|
||||
"length_penalty": length_penalty,
|
||||
"early_stopping": early_stopping,
|
||||
})
|
||||
else:
|
||||
for k in ["do_sample", "temperature"]:
|
||||
generate_params[k] = generate_state[k]
|
||||
generate_params["stop"] = generate_state["eos_token_ids"][-1]
|
||||
if not shared.args.no_stream:
|
||||
generate_params["max_new_tokens"] = 8
|
||||
|
||||
generate_params.update({
|
||||
"max_new_tokens": max_new_tokens if shared.args.no_stream else 8,
|
||||
"do_sample": do_sample,
|
||||
"temperature": temperature,
|
||||
"stop": eos_token_ids[-1],
|
||||
})
|
||||
if shared.args.no_cache:
|
||||
generate_params.update({"use_cache": False})
|
||||
if shared.args.deepspeed:
|
||||
@@ -252,7 +244,7 @@ def generate_reply(question, generate_state, eos_token=None, stopping_strings=[]
|
||||
|
||||
# Stream the output naively for FlexGen since it doesn't support 'stopping_criteria'
|
||||
else:
|
||||
for i in range(generate_state['max_new_tokens'] // 8 + 1):
|
||||
for i in range(max_new_tokens//8+1):
|
||||
clear_torch_cache()
|
||||
with torch.no_grad():
|
||||
output = shared.model.generate(**generate_params)[0]
|
||||
@@ -283,6 +275,6 @@ def generate_reply(question, generate_state, eos_token=None, stopping_strings=[]
|
||||
finally:
|
||||
t1 = time.time()
|
||||
original_tokens = len(original_input_ids[0])
|
||||
new_tokens = len(output) - original_tokens
|
||||
new_tokens = len(output)-original_tokens
|
||||
print(f"Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens})")
|
||||
return
|
||||
|
||||
@@ -19,10 +19,8 @@ CURRENT_STEPS = 0
|
||||
MAX_STEPS = 0
|
||||
CURRENT_GRADIENT_ACCUM = 1
|
||||
|
||||
|
||||
def get_dataset(path: str, ext: str):
|
||||
return ['None'] + sorted(set([k.stem for k in Path(path).glob(f'*.{ext}') if k.stem != 'put-trainer-datasets-here']), key=str.lower)
|
||||
|
||||
return ['None'] + sorted(set((k.stem for k in Path(path).glob(f'*.{ext}'))), key=str.lower)
|
||||
|
||||
def create_train_interface():
|
||||
with gr.Tab('Train LoRA', elem_id='lora-train-tab'):
|
||||
@@ -46,35 +44,29 @@ def create_train_interface():
|
||||
with gr.Tab(label="Formatted Dataset"):
|
||||
with gr.Row():
|
||||
dataset = gr.Dropdown(choices=get_dataset('training/datasets', 'json'), value='None', label='Dataset', info='The dataset file to use for training.')
|
||||
ui.create_refresh_button(dataset, lambda: None, lambda: {'choices': get_dataset('training/datasets', 'json')}, 'refresh-button')
|
||||
eval_dataset = gr.Dropdown(choices=get_dataset('training/datasets', 'json'), value='None', label='Evaluation Dataset', info='The (optional) dataset file used to evaluate the model after training.')
|
||||
ui.create_refresh_button(eval_dataset, lambda: None, lambda: {'choices': get_dataset('training/datasets', 'json')}, 'refresh-button')
|
||||
ui.create_refresh_button(dataset, lambda : None, lambda : {'choices': get_dataset('training/datasets', 'json')}, 'refresh-button')
|
||||
eval_dataset = gr.Dropdown(choices=get_dataset('training/datasets', 'json'), value='None', label='Evaluation Dataset', info='The dataset file used to evaluate the model after training.')
|
||||
ui.create_refresh_button(eval_dataset, lambda : None, lambda : {'choices': get_dataset('training/datasets', 'json')}, 'refresh-button')
|
||||
format = gr.Dropdown(choices=get_dataset('training/formats', 'json'), value='None', label='Data Format', info='The format file used to decide how to format the dataset input.')
|
||||
ui.create_refresh_button(format, lambda: None, lambda: {'choices': get_dataset('training/formats', 'json')}, 'refresh-button')
|
||||
|
||||
ui.create_refresh_button(format, lambda : None, lambda : {'choices': get_dataset('training/formats', 'json')}, 'refresh-button')
|
||||
with gr.Tab(label="Raw Text File"):
|
||||
with gr.Row():
|
||||
raw_text_file = gr.Dropdown(choices=get_dataset('training/datasets', 'txt'), value='None', label='Text File', info='The raw text file to use for training.')
|
||||
ui.create_refresh_button(raw_text_file, lambda: None, lambda: {'choices': get_dataset('training/datasets', 'txt')}, 'refresh-button')
|
||||
with gr.Row():
|
||||
overlap_len = gr.Slider(label='Overlap Length', minimum=0, maximum=512, value=128, step=16, info='Overlap length - ie how many tokens from the prior chunk of text to include into the next chunk. (The chunks themselves will be of a size determined by Cutoff Length below). Setting overlap to exactly half the cutoff length may be ideal.')
|
||||
newline_favor_len = gr.Slider(label='Prefer Newline Cut Length', minimum=0, maximum=512, value=128, step=16, info='Length (in characters, not tokens) of the maximum distance to shift an overlap cut by to ensure chunks cut at newlines. If too low, cuts may occur in the middle of lines.')
|
||||
ui.create_refresh_button(raw_text_file, lambda : None, lambda : {'choices': get_dataset('training/datasets', 'txt')}, 'refresh-button')
|
||||
overlap_len = gr.Slider(label='Overlap Length', minimum=0,maximum=512, value=128, step=16, info='Overlap length - ie how many tokens from the prior chunk of text to include into the next chunk. (The chunks themselves will be of a size determined by Cutoff Length above). Setting overlap to exactly half the cutoff length may be ideal.')
|
||||
|
||||
with gr.Row():
|
||||
start_button = gr.Button("Start LoRA Training")
|
||||
stop_button = gr.Button("Interrupt")
|
||||
|
||||
output = gr.Markdown(value="Ready")
|
||||
start_button.click(do_train, [lora_name, micro_batch_size, batch_size, epochs, learning_rate, lora_rank, lora_alpha, lora_dropout,
|
||||
cutoff_len, dataset, eval_dataset, format, raw_text_file, overlap_len, newline_favor_len], [output])
|
||||
start_button.click(do_train, [lora_name, micro_batch_size, batch_size, epochs, learning_rate, lora_rank, lora_alpha, lora_dropout, cutoff_len, dataset, eval_dataset, format, raw_text_file, overlap_len], [output])
|
||||
stop_button.click(do_interrupt, [], [], cancels=[], queue=False)
|
||||
|
||||
|
||||
def do_interrupt():
|
||||
global WANT_INTERRUPT
|
||||
WANT_INTERRUPT = True
|
||||
|
||||
|
||||
class Callbacks(transformers.TrainerCallback):
|
||||
def on_step_begin(self, args: transformers.TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, **kwargs):
|
||||
global CURRENT_STEPS, MAX_STEPS
|
||||
@@ -83,7 +75,6 @@ class Callbacks(transformers.TrainerCallback):
|
||||
if WANT_INTERRUPT:
|
||||
control.should_epoch_stop = True
|
||||
control.should_training_stop = True
|
||||
|
||||
def on_substep_end(self, args: transformers.TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, **kwargs):
|
||||
global CURRENT_STEPS
|
||||
CURRENT_STEPS += 1
|
||||
@@ -91,7 +82,6 @@ class Callbacks(transformers.TrainerCallback):
|
||||
control.should_epoch_stop = True
|
||||
control.should_training_stop = True
|
||||
|
||||
|
||||
def clean_path(base_path: str, path: str):
|
||||
""""Strips unusual symbols and forcibly builds a path as relative to the intended directory."""
|
||||
# TODO: Probably could do with a security audit to guarantee there's no ways this can be bypassed to target an unwanted path.
|
||||
@@ -101,9 +91,8 @@ def clean_path(base_path: str, path: str):
|
||||
return path
|
||||
return f'{Path(base_path).absolute()}/{path}'
|
||||
|
||||
|
||||
def do_train(lora_name: str, micro_batch_size: int, batch_size: int, epochs: int, learning_rate: str, lora_rank: int, lora_alpha: int, lora_dropout: float,
|
||||
cutoff_len: int, dataset: str, eval_dataset: str, format: str, raw_text_file: str, overlap_len: int, newline_favor_len: int):
|
||||
def do_train(lora_name: str, micro_batch_size: int, batch_size: int, epochs: int, learning_rate: str, lora_rank: int,
|
||||
lora_alpha: int, lora_dropout: float, cutoff_len: int, dataset: str, eval_dataset: str, format: str, raw_text_file: str, overlap_len: int):
|
||||
global WANT_INTERRUPT, CURRENT_STEPS, MAX_STEPS, CURRENT_GRADIENT_ACCUM
|
||||
WANT_INTERRUPT = False
|
||||
CURRENT_STEPS = 0
|
||||
@@ -114,25 +103,6 @@ def do_train(lora_name: str, micro_batch_size: int, batch_size: int, epochs: int
|
||||
lora_name = f"{shared.args.lora_dir}/{clean_path(None, lora_name)}"
|
||||
actual_lr = float(learning_rate)
|
||||
|
||||
model_type = type(shared.model).__name__
|
||||
if model_type != "LlamaForCausalLM":
|
||||
if model_type == "PeftModelForCausalLM":
|
||||
yield "You are trying to train a LoRA while you already have another LoRA loaded. This will work, but may have unexpected effects. *(Will continue anyway in 5 seconds, press `Interrupt` to stop.)*"
|
||||
print("Warning: Training LoRA over top of another LoRA. May have unexpected effects.")
|
||||
else:
|
||||
yield "LoRA training has only currently been validated for LLaMA models. Unexpected errors may follow. *(Will continue anyway in 5 seconds, press `Interrupt` to stop.)*"
|
||||
print(f"Warning: LoRA training has only currently been validated for LLaMA models. (Found model type: {model_type})")
|
||||
time.sleep(5)
|
||||
|
||||
if shared.args.wbits > 0 or shared.args.gptq_bits > 0:
|
||||
yield "LoRA training does not yet support 4bit. Please use `--load-in-8bit` for now."
|
||||
return
|
||||
|
||||
elif not shared.args.load_in_8bit:
|
||||
yield "It is highly recommended you use `--load-in-8bit` for LoRA training. *(Will continue anyway in 2 seconds, press `Interrupt` to stop.)*"
|
||||
print("Warning: It is highly recommended you use `--load-in-8bit` for LoRA training.")
|
||||
time.sleep(2) # Give it a moment for the message to show in UI before continuing
|
||||
|
||||
if cutoff_len <= 0 or micro_batch_size <= 0 or batch_size <= 0 or actual_lr <= 0 or lora_rank <= 0 or lora_alpha <= 0:
|
||||
yield "Cannot input zeroes."
|
||||
return
|
||||
@@ -156,20 +126,15 @@ def do_train(lora_name: str, micro_batch_size: int, batch_size: int, epochs: int
|
||||
raw_text = file.read()
|
||||
tokens = shared.tokenizer.encode(raw_text)
|
||||
del raw_text # Note: could be a gig for a large dataset, so delete redundant data as we go to be safe on RAM
|
||||
|
||||
tokens = list(split_chunks(tokens, cutoff_len - overlap_len))
|
||||
for i in range(1, len(tokens)):
|
||||
tokens[i] = tokens[i - 1][-overlap_len:] + tokens[i]
|
||||
text_chunks = [shared.tokenizer.decode(x) for x in tokens]
|
||||
del tokens
|
||||
|
||||
if newline_favor_len > 0:
|
||||
text_chunks = [cut_chunk_for_newline(x, newline_favor_len) for x in text_chunks]
|
||||
|
||||
train_data = Dataset.from_list([tokenize(x) for x in text_chunks])
|
||||
del text_chunks
|
||||
train_data = train_data.shuffle()
|
||||
data = Dataset.from_list([tokenize(x) for x in text_chunks])
|
||||
train_data = data.shuffle()
|
||||
eval_data = None
|
||||
del text_chunks
|
||||
|
||||
else:
|
||||
if dataset in ['None', '']:
|
||||
@@ -215,7 +180,7 @@ def do_train(lora_name: str, micro_batch_size: int, batch_size: int, epochs: int
|
||||
r=lora_rank,
|
||||
lora_alpha=lora_alpha,
|
||||
# TODO: Should target_modules be configurable?
|
||||
target_modules=["q_proj", "v_proj"],
|
||||
target_modules=[ "q_proj", "v_proj" ],
|
||||
lora_dropout=lora_dropout,
|
||||
bias="none",
|
||||
task_type="CAUSAL_LM"
|
||||
@@ -267,37 +232,33 @@ def do_train(lora_name: str, micro_batch_size: int, batch_size: int, epochs: int
|
||||
# TODO: save/load checkpoints to resume from?
|
||||
print("Starting training...")
|
||||
yield "Starting..."
|
||||
if WANT_INTERRUPT:
|
||||
yield "Interrupted before start."
|
||||
return
|
||||
|
||||
def threaded_run():
|
||||
def threadedRun():
|
||||
trainer.train()
|
||||
|
||||
thread = threading.Thread(target=threaded_run)
|
||||
thread = threading.Thread(target=threadedRun)
|
||||
thread.start()
|
||||
last_step = 0
|
||||
start_time = time.perf_counter()
|
||||
lastStep = 0
|
||||
startTime = time.perf_counter()
|
||||
|
||||
while thread.is_alive():
|
||||
time.sleep(0.5)
|
||||
if WANT_INTERRUPT:
|
||||
yield "Interrupting, please wait... *(Run will stop after the current training step completes.)*"
|
||||
|
||||
elif CURRENT_STEPS != last_step:
|
||||
last_step = CURRENT_STEPS
|
||||
time_elapsed = time.perf_counter() - start_time
|
||||
if time_elapsed <= 0:
|
||||
timer_info = ""
|
||||
total_time_estimate = 999
|
||||
elif CURRENT_STEPS != lastStep:
|
||||
lastStep = CURRENT_STEPS
|
||||
timeElapsed = time.perf_counter() - startTime
|
||||
if timeElapsed <= 0:
|
||||
timerInfo = ""
|
||||
totalTimeEstimate = 999
|
||||
else:
|
||||
its = CURRENT_STEPS / time_elapsed
|
||||
its = CURRENT_STEPS / timeElapsed
|
||||
if its > 1:
|
||||
timer_info = f"`{its:.2f}` it/s"
|
||||
timerInfo = f"`{its:.2f}` it/s"
|
||||
else:
|
||||
timer_info = f"`{1.0/its:.2f}` s/it"
|
||||
total_time_estimate = (1.0 / its) * (MAX_STEPS)
|
||||
yield f"Running... **{CURRENT_STEPS}** / **{MAX_STEPS}** ... {timer_info}, {format_time(time_elapsed)} / {format_time(total_time_estimate)} ... {format_time(total_time_estimate - time_elapsed)} remaining"
|
||||
timerInfo = f"`{1.0/its:.2f}` s/it"
|
||||
totalTimeEstimate = (1.0/its) * (MAX_STEPS)
|
||||
yield f"Running... **{CURRENT_STEPS}** / **{MAX_STEPS}** ... {timerInfo}, `{timeElapsed:.0f}`/`{totalTimeEstimate:.0f}` seconds"
|
||||
|
||||
print("Training complete, saving...")
|
||||
lora_model.save_pretrained(lora_name)
|
||||
@@ -309,31 +270,6 @@ def do_train(lora_name: str, micro_batch_size: int, batch_size: int, epochs: int
|
||||
print("Training complete!")
|
||||
yield f"Done! LoRA saved to `{lora_name}`"
|
||||
|
||||
|
||||
def split_chunks(arr, step):
|
||||
for i in range(0, len(arr), step):
|
||||
yield arr[i:i + step]
|
||||
|
||||
|
||||
def cut_chunk_for_newline(chunk: str, max_length: int):
|
||||
if '\n' not in chunk:
|
||||
return chunk
|
||||
first_newline = chunk.index('\n')
|
||||
if first_newline < max_length:
|
||||
chunk = chunk[first_newline + 1:]
|
||||
if '\n' not in chunk:
|
||||
return chunk
|
||||
last_newline = chunk.rindex('\n')
|
||||
if len(chunk) - last_newline < max_length:
|
||||
chunk = chunk[:last_newline]
|
||||
return chunk
|
||||
|
||||
|
||||
def format_time(seconds: float):
|
||||
if seconds < 120:
|
||||
return f"`{seconds:.0f}` seconds"
|
||||
minutes = seconds / 60
|
||||
if minutes < 120:
|
||||
return f"`{minutes:.0f}` minutes"
|
||||
hours = minutes / 60
|
||||
return f"`{hours:.0f}` hours"
|
||||
|
||||
@@ -13,7 +13,6 @@ with open(Path(__file__).resolve().parent / '../css/main.js', 'r') as f:
|
||||
with open(Path(__file__).resolve().parent / '../css/chat.js', 'r') as f:
|
||||
chat_js = f.read()
|
||||
|
||||
|
||||
class ToolButton(gr.Button, gr.components.FormComponent):
|
||||
"""Small button with single emoji as text, fits inside gradio forms"""
|
||||
|
||||
@@ -23,7 +22,6 @@ class ToolButton(gr.Button, gr.components.FormComponent):
|
||||
def get_block_name(self):
|
||||
return "button"
|
||||
|
||||
|
||||
def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id):
|
||||
def refresh():
|
||||
refresh_method()
|
||||
|
||||
@@ -1,6 +0,0 @@
|
||||
do_sample=True
|
||||
top_p=0.1
|
||||
top_k=40
|
||||
temperature=0.7
|
||||
repetition_penalty=1.18
|
||||
typical_p=1.0
|
||||
@@ -1 +0,0 @@
|
||||
<|prompter|>Write a story about future of AI development<|endoftext|><|assistant|>
|
||||
@@ -3,11 +3,12 @@ bitsandbytes==0.37.2
|
||||
datasets
|
||||
flexgen==0.1.7
|
||||
gradio==3.24.1
|
||||
llamacpp==0.1.11
|
||||
markdown
|
||||
numpy
|
||||
peft==0.2.0
|
||||
requests
|
||||
rwkv==0.7.3
|
||||
rwkv==0.7.2
|
||||
safetensors==0.3.0
|
||||
sentencepiece
|
||||
pyyaml
|
||||
|
||||
215
server.py
215
server.py
@@ -1,7 +1,3 @@
|
||||
import os
|
||||
|
||||
os.environ['GRADIO_ANALYTICS_ENABLED'] = 'False'
|
||||
|
||||
import io
|
||||
import json
|
||||
import re
|
||||
@@ -12,11 +8,10 @@ from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
import gradio as gr
|
||||
from PIL import Image
|
||||
|
||||
import modules.extensions as extensions_module
|
||||
from modules import chat, shared, training, ui, api
|
||||
from modules.html_generator import chat_html_wrapper
|
||||
from modules import chat, shared, training, ui
|
||||
from modules.html_generator import generate_chat_html
|
||||
from modules.LoRA import add_lora_to_model
|
||||
from modules.models import load_model, load_soft_prompt
|
||||
from modules.text_generation import (clear_torch_cache, generate_reply,
|
||||
@@ -34,18 +29,15 @@ if settings_file is not None:
|
||||
for item in new_settings:
|
||||
shared.settings[item] = new_settings[item]
|
||||
|
||||
|
||||
def get_available_models():
|
||||
if shared.args.flexgen:
|
||||
return sorted([re.sub('-np$', '', item.name) for item in list(Path(f'{shared.args.model_dir}/').glob('*')) if item.name.endswith('-np')], key=str.lower)
|
||||
else:
|
||||
return sorted([re.sub('.pth$', '', item.name) for item in list(Path(f'{shared.args.model_dir}/').glob('*')) if not item.name.endswith(('.txt', '-np', '.pt', '.json'))], key=str.lower)
|
||||
|
||||
|
||||
def get_available_presets():
|
||||
return sorted(set((k.stem for k in Path('presets').glob('*.txt'))), key=str.lower)
|
||||
|
||||
|
||||
def get_available_prompts():
|
||||
prompts = []
|
||||
prompts += sorted(set((k.stem for k in Path('prompts').glob('[0-9]*.txt'))), key=str.lower, reverse=True)
|
||||
@@ -53,37 +45,23 @@ def get_available_prompts():
|
||||
prompts += ['None']
|
||||
return prompts
|
||||
|
||||
|
||||
def get_available_characters():
|
||||
paths = (x for x in Path('characters').iterdir() if x.suffix in ('.json', '.yaml', '.yml'))
|
||||
return ['None'] + sorted(set((k.stem for k in paths if k.stem != "instruction-following")), key=str.lower)
|
||||
|
||||
|
||||
def get_available_instruction_templates():
|
||||
path = "characters/instruction-following"
|
||||
paths = []
|
||||
if os.path.exists(path):
|
||||
paths = (x for x in Path(path).iterdir() if x.suffix in ('.json', '.yaml', '.yml'))
|
||||
return ['None'] + sorted(set((k.stem for k in paths)), key=str.lower)
|
||||
|
||||
|
||||
def get_available_extensions():
|
||||
return sorted(set(map(lambda x: x.parts[1], Path('extensions').glob('*/script.py'))), key=str.lower)
|
||||
|
||||
return sorted(set(map(lambda x : x.parts[1], Path('extensions').glob('*/script.py'))), key=str.lower)
|
||||
|
||||
def get_available_softprompts():
|
||||
return ['None'] + sorted(set((k.stem for k in Path('softprompts').glob('*.zip'))), key=str.lower)
|
||||
|
||||
|
||||
def get_available_loras():
|
||||
return ['None'] + sorted([item.name for item in list(Path(shared.args.lora_dir).glob('*')) if not item.name.endswith(('.txt', '-np', '.pt', '.json'))], key=str.lower)
|
||||
|
||||
|
||||
def unload_model():
|
||||
shared.model = shared.tokenizer = None
|
||||
clear_torch_cache()
|
||||
|
||||
|
||||
def load_model_wrapper(selected_model):
|
||||
if selected_model != shared.model_name:
|
||||
shared.model_name = selected_model
|
||||
@@ -94,13 +72,11 @@ def load_model_wrapper(selected_model):
|
||||
|
||||
return selected_model
|
||||
|
||||
|
||||
def load_lora_wrapper(selected_lora):
|
||||
add_lora_to_model(selected_lora)
|
||||
return selected_lora
|
||||
|
||||
|
||||
def load_preset_values(preset_menu, state, return_dict=False):
|
||||
def load_preset_values(preset_menu, return_dict=False):
|
||||
generate_params = {
|
||||
'do_sample': True,
|
||||
'temperature': 1,
|
||||
@@ -122,14 +98,13 @@ def load_preset_values(preset_menu, state, return_dict=False):
|
||||
i = i.rstrip(',').strip().split('=')
|
||||
if len(i) == 2 and i[0].strip() != 'tokens':
|
||||
generate_params[i[0].strip()] = eval(i[1].strip())
|
||||
|
||||
generate_params['temperature'] = min(1.99, generate_params['temperature'])
|
||||
|
||||
if return_dict:
|
||||
return generate_params
|
||||
else:
|
||||
state.update(generate_params)
|
||||
return state, *[generate_params[k] for k in ['do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']]
|
||||
|
||||
return generate_params['do_sample'], generate_params['temperature'], generate_params['top_p'], generate_params['typical_p'], generate_params['repetition_penalty'], generate_params['encoder_repetition_penalty'], generate_params['top_k'], generate_params['min_length'], generate_params['no_repeat_ngram_size'], generate_params['num_beams'], generate_params['penalty_alpha'], generate_params['length_penalty'], generate_params['early_stopping']
|
||||
|
||||
def upload_soft_prompt(file):
|
||||
with zipfile.ZipFile(io.BytesIO(file)) as zf:
|
||||
@@ -143,14 +118,23 @@ def upload_soft_prompt(file):
|
||||
|
||||
return name
|
||||
|
||||
def create_model_and_preset_menus():
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
with gr.Row():
|
||||
shared.gradio['model_menu'] = gr.Dropdown(choices=available_models, value=shared.model_name, label='Model')
|
||||
ui.create_refresh_button(shared.gradio['model_menu'], lambda : None, lambda : {'choices': get_available_models()}, 'refresh-button')
|
||||
with gr.Column():
|
||||
with gr.Row():
|
||||
shared.gradio['preset_menu'] = gr.Dropdown(choices=available_presets, value=default_preset if not shared.args.flexgen else 'Naive', label='Generation parameters preset')
|
||||
ui.create_refresh_button(shared.gradio['preset_menu'], lambda : None, lambda : {'choices': get_available_presets()}, 'refresh-button')
|
||||
|
||||
def save_prompt(text):
|
||||
fname = f"{datetime.now().strftime('%Y-%m-%d-%H%M%S')}.txt"
|
||||
fname = f"{datetime.now().strftime('%Y-%m-%d-%H:%M:%S')}.txt"
|
||||
with open(Path(f'prompts/{fname}'), 'w', encoding='utf-8') as f:
|
||||
f.write(text)
|
||||
return f"Saved to prompts/{fname}"
|
||||
|
||||
|
||||
def load_prompt(fname):
|
||||
if fname in ['None', '']:
|
||||
return ''
|
||||
@@ -161,13 +145,12 @@ def load_prompt(fname):
|
||||
text = text[:-1]
|
||||
return text
|
||||
|
||||
|
||||
def create_prompt_menus():
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
with gr.Row():
|
||||
shared.gradio['prompt_menu'] = gr.Dropdown(choices=get_available_prompts(), value='None', label='Prompt')
|
||||
ui.create_refresh_button(shared.gradio['prompt_menu'], lambda: None, lambda: {'choices': get_available_prompts()}, 'refresh-button')
|
||||
ui.create_refresh_button(shared.gradio['prompt_menu'], lambda : None, lambda : {'choices': get_available_prompts()}, 'refresh-button')
|
||||
|
||||
with gr.Column():
|
||||
with gr.Column():
|
||||
@@ -177,33 +160,12 @@ def create_prompt_menus():
|
||||
shared.gradio['prompt_menu'].change(load_prompt, [shared.gradio['prompt_menu']], [shared.gradio['textbox']], show_progress=False)
|
||||
shared.gradio['save_prompt'].click(save_prompt, [shared.gradio['textbox']], [shared.gradio['status']], show_progress=False)
|
||||
|
||||
|
||||
def create_model_menus():
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
with gr.Row():
|
||||
shared.gradio['model_menu'] = gr.Dropdown(choices=available_models, value=shared.model_name, label='Model')
|
||||
ui.create_refresh_button(shared.gradio['model_menu'], lambda: None, lambda: {'choices': get_available_models()}, 'refresh-button')
|
||||
with gr.Column():
|
||||
with gr.Row():
|
||||
shared.gradio['lora_menu'] = gr.Dropdown(choices=available_loras, value=shared.lora_name, label='LoRA')
|
||||
ui.create_refresh_button(shared.gradio['lora_menu'], lambda: None, lambda: {'choices': get_available_loras()}, 'refresh-button')
|
||||
|
||||
shared.gradio['model_menu'].change(load_model_wrapper, shared.gradio['model_menu'], shared.gradio['model_menu'], show_progress=True)
|
||||
shared.gradio['lora_menu'].change(load_lora_wrapper, shared.gradio['lora_menu'], shared.gradio['lora_menu'], show_progress=True)
|
||||
|
||||
|
||||
def create_settings_menus(default_preset):
|
||||
generate_params = load_preset_values(default_preset if not shared.args.flexgen else 'Naive', {}, return_dict=True)
|
||||
for k in ['max_new_tokens', 'seed', 'stop_at_newline', 'chat_prompt_size', 'chat_generation_attempts']:
|
||||
generate_params[k] = shared.settings[k]
|
||||
shared.gradio['generate_state'] = gr.State(generate_params)
|
||||
generate_params = load_preset_values(default_preset if not shared.args.flexgen else 'Naive', return_dict=True)
|
||||
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
with gr.Row():
|
||||
shared.gradio['preset_menu'] = gr.Dropdown(choices=available_presets, value=default_preset if not shared.args.flexgen else 'Naive', label='Generation parameters preset')
|
||||
ui.create_refresh_button(shared.gradio['preset_menu'], lambda: None, lambda: {'choices': get_available_presets()}, 'refresh-button')
|
||||
create_model_and_preset_menus()
|
||||
with gr.Column():
|
||||
shared.gradio['seed'] = gr.Number(value=shared.settings['seed'], label='Seed (-1 for random)')
|
||||
|
||||
@@ -214,12 +176,12 @@ def create_settings_menus(default_preset):
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
shared.gradio['temperature'] = gr.Slider(0.01, 1.99, value=generate_params['temperature'], step=0.01, label='temperature')
|
||||
shared.gradio['top_p'] = gr.Slider(0.0, 1.0, value=generate_params['top_p'], step=0.01, label='top_p')
|
||||
shared.gradio['top_k'] = gr.Slider(0, 200, value=generate_params['top_k'], step=1, label='top_k')
|
||||
shared.gradio['typical_p'] = gr.Slider(0.0, 1.0, value=generate_params['typical_p'], step=0.01, label='typical_p')
|
||||
shared.gradio['top_p'] = gr.Slider(0.0,1.0,value=generate_params['top_p'],step=0.01,label='top_p')
|
||||
shared.gradio['top_k'] = gr.Slider(0,200,value=generate_params['top_k'],step=1,label='top_k')
|
||||
shared.gradio['typical_p'] = gr.Slider(0.0,1.0,value=generate_params['typical_p'],step=0.01,label='typical_p')
|
||||
with gr.Column():
|
||||
shared.gradio['repetition_penalty'] = gr.Slider(1.0, 1.5, value=generate_params['repetition_penalty'], step=0.01, label='repetition_penalty')
|
||||
shared.gradio['encoder_repetition_penalty'] = gr.Slider(0.8, 1.5, value=generate_params['encoder_repetition_penalty'], step=0.01, label='encoder_repetition_penalty')
|
||||
shared.gradio['repetition_penalty'] = gr.Slider(1.0, 1.5, value=generate_params['repetition_penalty'],step=0.01,label='repetition_penalty')
|
||||
shared.gradio['encoder_repetition_penalty'] = gr.Slider(0.8, 1.5, value=generate_params['encoder_repetition_penalty'],step=0.01,label='encoder_repetition_penalty')
|
||||
shared.gradio['no_repeat_ngram_size'] = gr.Slider(0, 20, step=1, value=generate_params['no_repeat_ngram_size'], label='no_repeat_ngram_size')
|
||||
shared.gradio['min_length'] = gr.Slider(0, 2000, step=1, value=generate_params['min_length'] if shared.args.no_stream else 0, label='min_length', interactive=shared.args.no_stream)
|
||||
shared.gradio['do_sample'] = gr.Checkbox(value=generate_params['do_sample'], label='do_sample')
|
||||
@@ -227,6 +189,7 @@ def create_settings_menus(default_preset):
|
||||
with gr.Box():
|
||||
gr.Markdown('Contrastive search')
|
||||
shared.gradio['penalty_alpha'] = gr.Slider(0, 5, value=generate_params['penalty_alpha'], label='penalty_alpha')
|
||||
|
||||
with gr.Box():
|
||||
gr.Markdown('Beam search (uses a lot of VRAM)')
|
||||
with gr.Row():
|
||||
@@ -236,24 +199,30 @@ def create_settings_menus(default_preset):
|
||||
shared.gradio['length_penalty'] = gr.Slider(-5, 5, value=generate_params['length_penalty'], label='length_penalty')
|
||||
shared.gradio['early_stopping'] = gr.Checkbox(value=generate_params['early_stopping'], label='early_stopping')
|
||||
|
||||
with gr.Row():
|
||||
shared.gradio['lora_menu'] = gr.Dropdown(choices=available_loras, value=shared.lora_name, label='LoRA')
|
||||
ui.create_refresh_button(shared.gradio['lora_menu'], lambda : None, lambda : {'choices': get_available_loras()}, 'refresh-button')
|
||||
|
||||
with gr.Accordion('Soft prompt', open=False):
|
||||
with gr.Row():
|
||||
shared.gradio['softprompts_menu'] = gr.Dropdown(choices=available_softprompts, value='None', label='Soft prompt')
|
||||
ui.create_refresh_button(shared.gradio['softprompts_menu'], lambda: None, lambda: {'choices': get_available_softprompts()}, 'refresh-button')
|
||||
ui.create_refresh_button(shared.gradio['softprompts_menu'], lambda : None, lambda : {'choices': get_available_softprompts()}, 'refresh-button')
|
||||
|
||||
gr.Markdown('Upload a soft prompt (.zip format):')
|
||||
with gr.Row():
|
||||
shared.gradio['upload_softprompt'] = gr.File(type='binary', file_types=['.zip'])
|
||||
|
||||
shared.gradio['preset_menu'].change(load_preset_values, [shared.gradio[k] for k in ['preset_menu', 'generate_state']], [shared.gradio[k] for k in ['generate_state', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']])
|
||||
shared.gradio['softprompts_menu'].change(load_soft_prompt, shared.gradio['softprompts_menu'], shared.gradio['softprompts_menu'], show_progress=True)
|
||||
shared.gradio['upload_softprompt'].upload(upload_soft_prompt, shared.gradio['upload_softprompt'], shared.gradio['softprompts_menu'])
|
||||
|
||||
shared.gradio['model_menu'].change(load_model_wrapper, [shared.gradio['model_menu']], [shared.gradio['model_menu']], show_progress=True)
|
||||
shared.gradio['preset_menu'].change(load_preset_values, [shared.gradio['preset_menu']], [shared.gradio[k] for k in ['do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']])
|
||||
shared.gradio['lora_menu'].change(load_lora_wrapper, [shared.gradio['lora_menu']], [shared.gradio['lora_menu']], show_progress=True)
|
||||
shared.gradio['softprompts_menu'].change(load_soft_prompt, [shared.gradio['softprompts_menu']], [shared.gradio['softprompts_menu']], show_progress=True)
|
||||
shared.gradio['upload_softprompt'].upload(upload_soft_prompt, [shared.gradio['upload_softprompt']], [shared.gradio['softprompts_menu']])
|
||||
|
||||
def set_interface_arguments(interface_mode, extensions, bool_active):
|
||||
modes = ["default", "notebook", "chat", "cai_chat"]
|
||||
cmd_list = vars(shared.args)
|
||||
bool_list = [k for k in cmd_list if type(cmd_list[k]) is bool and k not in modes]
|
||||
#int_list = [k for k in cmd_list if type(k) is int]
|
||||
|
||||
shared.args.extensions = extensions
|
||||
for k in modes[1:]:
|
||||
@@ -268,7 +237,6 @@ def set_interface_arguments(interface_mode, extensions, bool_active):
|
||||
|
||||
shared.need_restart = True
|
||||
|
||||
|
||||
available_models = get_available_models()
|
||||
available_presets = get_available_presets()
|
||||
available_characters = get_available_characters()
|
||||
@@ -302,7 +270,7 @@ else:
|
||||
for i, model in enumerate(available_models):
|
||||
print(f'{i+1}. {model}')
|
||||
print(f'\nWhich one do you want to load? 1-{len(available_models)}\n')
|
||||
i = int(input()) - 1
|
||||
i = int(input())-1
|
||||
print()
|
||||
shared.model_name = available_models[i]
|
||||
shared.model, shared.tokenizer = load_model(shared.model_name)
|
||||
@@ -315,19 +283,22 @@ if shared.lora_name != "None":
|
||||
default_text = load_prompt(shared.settings['lora_prompts'][next((k for k in shared.settings['lora_prompts'] if re.match(k.lower(), shared.lora_name.lower())), 'default')])
|
||||
else:
|
||||
default_text = load_prompt(shared.settings['prompts'][next((k for k in shared.settings['prompts'] if re.match(k.lower(), shared.model_name.lower())), 'default')])
|
||||
title = 'Text generation web UI'
|
||||
|
||||
title ='Text generation web UI'
|
||||
|
||||
def create_interface():
|
||||
|
||||
gen_events = []
|
||||
if shared.args.extensions is not None and len(shared.args.extensions) > 0:
|
||||
extensions_module.load_extensions()
|
||||
|
||||
with gr.Blocks(css=ui.css if not shared.is_chat() else ui.css + ui.chat_css, analytics_enabled=False, title=title) as shared.gradio['interface']:
|
||||
with gr.Blocks(css=ui.css if not shared.is_chat() else ui.css+ui.chat_css, analytics_enabled=False, title=title) as shared.gradio['interface']:
|
||||
if shared.is_chat():
|
||||
shared.gradio['Chat input'] = gr.State()
|
||||
with gr.Tab("Text generation", elem_id="main"):
|
||||
shared.gradio['display'] = gr.HTML(value=chat_html_wrapper(shared.history['visible'], shared.settings['name1'], shared.settings['name2'], 'cai-chat'))
|
||||
if shared.args.cai_chat:
|
||||
shared.gradio['display'] = gr.HTML(value=generate_chat_html(shared.history['visible'], shared.settings['name1'], shared.settings['name2'], shared.character))
|
||||
else:
|
||||
shared.gradio['display'] = gr.Chatbot(value=shared.history['visible'], elem_id="gradio-chatbot")
|
||||
shared.gradio['textbox'] = gr.Textbox(label='Input')
|
||||
with gr.Row():
|
||||
shared.gradio['Generate'] = gr.Button('Generate')
|
||||
@@ -344,23 +315,14 @@ def create_interface():
|
||||
shared.gradio['Clear history-confirm'] = gr.Button('Confirm', variant="stop", visible=False)
|
||||
shared.gradio['Clear history-cancel'] = gr.Button('Cancel', visible=False)
|
||||
|
||||
shared.gradio["Chat mode"] = gr.Radio(choices=["cai-chat", "chat", "instruct"], value="cai-chat", label="Mode")
|
||||
shared.gradio["Instruction templates"] = gr.Dropdown(choices=get_available_instruction_templates(), label="Instruction template", value="None", visible=False)
|
||||
|
||||
with gr.Tab("Character", elem_id="chat-settings"):
|
||||
with gr.Row():
|
||||
with gr.Column(scale=8):
|
||||
shared.gradio['name1'] = gr.Textbox(value=shared.settings['name1'], lines=1, label='Your name')
|
||||
shared.gradio['name2'] = gr.Textbox(value=shared.settings['name2'], lines=1, label='Character\'s name')
|
||||
shared.gradio['greeting'] = gr.Textbox(value=shared.settings['greeting'], lines=4, label='Greeting')
|
||||
shared.gradio['context'] = gr.Textbox(value=shared.settings['context'], lines=4, label='Context')
|
||||
shared.gradio['end_of_turn'] = gr.Textbox(value=shared.settings["end_of_turn"], lines=1, label='End of turn string')
|
||||
with gr.Column(scale=1):
|
||||
shared.gradio['character_picture'] = gr.Image(label='Character picture', type="pil")
|
||||
shared.gradio['your_picture'] = gr.Image(label='Your picture', type="pil", value=Image.open(Path("cache/pfp_me.png")) if Path("cache/pfp_me.png").exists() else None)
|
||||
shared.gradio['greeting'] = gr.Textbox(value=shared.settings['greeting'], lines=2, label='Greeting')
|
||||
shared.gradio['context'] = gr.Textbox(value=shared.settings['context'], lines=8, label='Context')
|
||||
with gr.Row():
|
||||
shared.gradio['character_menu'] = gr.Dropdown(choices=available_characters, value='None', label='Character', elem_id='character-menu')
|
||||
ui.create_refresh_button(shared.gradio['character_menu'], lambda: None, lambda: {'choices': get_available_characters()}, 'refresh-button')
|
||||
ui.create_refresh_button(shared.gradio['character_menu'], lambda : None, lambda : {'choices': get_available_characters()}, 'refresh-button')
|
||||
|
||||
with gr.Row():
|
||||
with gr.Tab('Chat history'):
|
||||
@@ -385,6 +347,8 @@ def create_interface():
|
||||
|
||||
gr.Markdown("# TavernAI PNG format")
|
||||
shared.gradio['upload_img_tavern'] = gr.File(type='binary', file_types=['image'])
|
||||
with gr.Tab('Upload your profile picture'):
|
||||
shared.gradio['upload_img_me'] = gr.File(type='binary', file_types=['image'])
|
||||
|
||||
with gr.Tab("Parameters", elem_id="parameters"):
|
||||
with gr.Box():
|
||||
@@ -395,61 +359,60 @@ def create_interface():
|
||||
shared.gradio['chat_prompt_size_slider'] = gr.Slider(minimum=shared.settings['chat_prompt_size_min'], maximum=shared.settings['chat_prompt_size_max'], step=1, label='Maximum prompt size in tokens', value=shared.settings['chat_prompt_size'])
|
||||
with gr.Column():
|
||||
shared.gradio['chat_generation_attempts'] = gr.Slider(minimum=shared.settings['chat_generation_attempts_min'], maximum=shared.settings['chat_generation_attempts_max'], value=shared.settings['chat_generation_attempts'], step=1, label='Generation attempts (for longer replies)')
|
||||
shared.gradio['stop_at_newline'] = gr.Checkbox(value=shared.settings['stop_at_newline'], label='Stop generating at new line character?')
|
||||
shared.gradio['check'] = gr.Checkbox(value=shared.settings['stop_at_newline'], label='Stop generating at new line character?')
|
||||
|
||||
create_settings_menus(default_preset)
|
||||
|
||||
shared.input_params = [shared.gradio[k] for k in ['Chat input', 'generate_state', 'name1', 'name2', 'context', 'Chat mode', 'end_of_turn']]
|
||||
function_call = 'chat.cai_chatbot_wrapper' if shared.args.cai_chat else 'chat.chatbot_wrapper'
|
||||
shared.input_params = [shared.gradio[k] for k in ['Chat input', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'seed', 'name1', 'name2', 'context', 'check', 'chat_prompt_size_slider', 'chat_generation_attempts']]
|
||||
|
||||
def set_chat_input(textbox):
|
||||
return textbox, ""
|
||||
|
||||
gen_events.append(shared.gradio['Generate'].click(set_chat_input, shared.gradio['textbox'], [shared.gradio['Chat input'], shared.gradio['textbox']], show_progress=False))
|
||||
gen_events.append(shared.gradio['Generate'].click(chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream))
|
||||
gen_events.append(shared.gradio['Generate'].click(eval(function_call), shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream))
|
||||
gen_events.append(shared.gradio['textbox'].submit(set_chat_input, shared.gradio['textbox'], [shared.gradio['Chat input'], shared.gradio['textbox']], show_progress=False))
|
||||
gen_events.append(shared.gradio['textbox'].submit(chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream))
|
||||
gen_events.append(shared.gradio['textbox'].submit(eval(function_call), shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream))
|
||||
gen_events.append(shared.gradio['Regenerate'].click(chat.regenerate_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream))
|
||||
gen_events.append(shared.gradio['Impersonate'].click(chat.impersonate_wrapper, shared.input_params, shared.gradio['textbox'], show_progress=shared.args.no_stream))
|
||||
shared.gradio['Stop'].click(stop_everything_event, [], [], queue=False, cancels=gen_events if shared.args.no_stream else None)
|
||||
|
||||
shared.gradio['Copy last reply'].click(chat.send_last_reply_to_input, [], shared.gradio['textbox'], show_progress=shared.args.no_stream)
|
||||
shared.gradio['Replace last reply'].click(chat.replace_last_reply, [shared.gradio[k] for k in ['textbox', 'name1', 'name2', 'Chat mode']], shared.gradio['display'], show_progress=shared.args.no_stream)
|
||||
shared.gradio['Replace last reply'].click(chat.replace_last_reply, [shared.gradio['textbox'], shared.gradio['name1'], shared.gradio['name2']], shared.gradio['display'], show_progress=shared.args.no_stream)
|
||||
|
||||
# Clear history with confirmation
|
||||
clear_arr = [shared.gradio[k] for k in ['Clear history-confirm', 'Clear history', 'Clear history-cancel']]
|
||||
shared.gradio['Clear history'].click(lambda: [gr.update(visible=True), gr.update(visible=False), gr.update(visible=True)], None, clear_arr)
|
||||
shared.gradio['Clear history-confirm'].click(lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, clear_arr)
|
||||
shared.gradio['Clear history-confirm'].click(chat.clear_chat_log, [shared.gradio[k] for k in ['name1', 'name2', 'greeting', 'Chat mode']], shared.gradio['display'])
|
||||
shared.gradio['Clear history-cancel'].click(lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, clear_arr)
|
||||
shared.gradio['Chat mode'].change(lambda x: gr.update(visible=x == 'instruct'), shared.gradio['Chat mode'], shared.gradio['Instruction templates'])
|
||||
shared.gradio['Clear history'].click(lambda :[gr.update(visible=True), gr.update(visible=False), gr.update(visible=True)], None, clear_arr)
|
||||
shared.gradio['Clear history-confirm'].click(lambda :[gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, clear_arr)
|
||||
shared.gradio['Clear history-confirm'].click(chat.clear_chat_log, [shared.gradio['name1'], shared.gradio['name2'], shared.gradio['greeting']], shared.gradio['display'])
|
||||
shared.gradio['Clear history-cancel'].click(lambda :[gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, clear_arr)
|
||||
|
||||
shared.gradio['Remove last'].click(chat.remove_last_message, [shared.gradio[k] for k in ['name1', 'name2', 'Chat mode']], [shared.gradio['display'], shared.gradio['textbox']], show_progress=False)
|
||||
shared.gradio['Remove last'].click(chat.remove_last_message, [shared.gradio['name1'], shared.gradio['name2']], [shared.gradio['display'], shared.gradio['textbox']], show_progress=False)
|
||||
shared.gradio['download_button'].click(chat.save_history, inputs=[], outputs=[shared.gradio['download']])
|
||||
shared.gradio['Upload character'].click(chat.upload_character, [shared.gradio['upload_json'], shared.gradio['upload_img_bot']], [shared.gradio['character_menu']])
|
||||
|
||||
# Clearing stuff and saving the history
|
||||
for i in ['Generate', 'Regenerate', 'Replace last reply']:
|
||||
shared.gradio[i].click(lambda x: '', shared.gradio['textbox'], shared.gradio['textbox'], show_progress=False)
|
||||
shared.gradio[i].click(lambda: chat.save_history(timestamp=False), [], [], show_progress=False)
|
||||
shared.gradio['Clear history-confirm'].click(lambda: chat.save_history(timestamp=False), [], [], show_progress=False)
|
||||
shared.gradio[i].click(lambda : chat.save_history(timestamp=False), [], [], show_progress=False)
|
||||
shared.gradio['Clear history-confirm'].click(lambda : chat.save_history(timestamp=False), [], [], show_progress=False)
|
||||
shared.gradio['textbox'].submit(lambda x: '', shared.gradio['textbox'], shared.gradio['textbox'], show_progress=False)
|
||||
shared.gradio['textbox'].submit(lambda: chat.save_history(timestamp=False), [], [], show_progress=False)
|
||||
shared.gradio['textbox'].submit(lambda : chat.save_history(timestamp=False), [], [], show_progress=False)
|
||||
|
||||
shared.gradio['character_menu'].change(chat.load_character, [shared.gradio[k] for k in ['character_menu', 'name1', 'name2', 'Chat mode']], [shared.gradio[k] for k in ['name1', 'name2', 'character_picture', 'greeting', 'context', 'end_of_turn', 'display']])
|
||||
shared.gradio['Instruction templates'].change(lambda character, name1, name2, mode: chat.load_character(character, name1, name2, mode), [shared.gradio[k] for k in ['Instruction templates', 'name1', 'name2', 'Chat mode']], [shared.gradio[k] for k in ['name1', 'name2', 'character_picture', 'greeting', 'context', 'end_of_turn', 'display']])
|
||||
shared.gradio['upload_chat_history'].upload(chat.load_history, [shared.gradio[k] for k in ['upload_chat_history', 'name1', 'name2']], [])
|
||||
shared.gradio['character_menu'].change(chat.load_character, [shared.gradio[k] for k in ['character_menu', 'name1', 'name2']], [shared.gradio[k] for k in ['name1', 'name2', 'greeting', 'context', 'display']])
|
||||
shared.gradio['upload_chat_history'].upload(chat.load_history, [shared.gradio['upload_chat_history'], shared.gradio['name1'], shared.gradio['name2']], [])
|
||||
shared.gradio['upload_img_tavern'].upload(chat.upload_tavern_character, [shared.gradio['upload_img_tavern'], shared.gradio['name1'], shared.gradio['name2']], [shared.gradio['character_menu']])
|
||||
shared.gradio['your_picture'].change(chat.upload_your_profile_picture, [shared.gradio[k] for k in ['your_picture', 'name1', 'name2', 'Chat mode']], shared.gradio['display'])
|
||||
shared.gradio['upload_img_me'].upload(chat.upload_your_profile_picture, [shared.gradio['upload_img_me']], [])
|
||||
|
||||
reload_inputs = [shared.gradio[k] for k in ['name1', 'name2', 'Chat mode']]
|
||||
shared.gradio['upload_chat_history'].upload(chat.redraw_html, reload_inputs, [shared.gradio['display']])
|
||||
shared.gradio['Stop'].click(chat.redraw_html, reload_inputs, [shared.gradio['display']])
|
||||
shared.gradio['Instruction templates'].change(chat.redraw_html, reload_inputs, [shared.gradio['display']])
|
||||
shared.gradio['Chat mode'].change(chat.redraw_html, reload_inputs, [shared.gradio['display']])
|
||||
reload_func = chat.redraw_html if shared.args.cai_chat else lambda : shared.history['visible']
|
||||
reload_inputs = [shared.gradio['name1'], shared.gradio['name2']] if shared.args.cai_chat else []
|
||||
shared.gradio['upload_chat_history'].upload(reload_func, reload_inputs, [shared.gradio['display']])
|
||||
shared.gradio['upload_img_me'].upload(reload_func, reload_inputs, [shared.gradio['display']])
|
||||
shared.gradio['Stop'].click(reload_func, reload_inputs, [shared.gradio['display']])
|
||||
|
||||
shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js+ui.chat_js}}}")
|
||||
shared.gradio['interface'].load(lambda: chat.load_default_history(shared.settings['name1'], shared.settings['name2']), None, None)
|
||||
shared.gradio['interface'].load(chat.redraw_html, reload_inputs, [shared.gradio['display']], show_progress=True)
|
||||
shared.gradio['interface'].load(lambda : chat.load_default_history(shared.settings['name1'], shared.settings['name2']), None, None)
|
||||
shared.gradio['interface'].load(reload_func, reload_inputs, [shared.gradio['display']], show_progress=True)
|
||||
|
||||
elif shared.args.notebook:
|
||||
with gr.Tab("Text generation", elem_id="main"):
|
||||
@@ -479,9 +442,9 @@ def create_interface():
|
||||
with gr.Tab("Parameters", elem_id="parameters"):
|
||||
create_settings_menus(default_preset)
|
||||
|
||||
shared.input_params = [shared.gradio[k] for k in ['textbox', 'generate_state']]
|
||||
shared.input_params = [shared.gradio[k] for k in ['textbox', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'seed']]
|
||||
output_params = [shared.gradio[k] for k in ['textbox', 'markdown', 'html']]
|
||||
gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream))
|
||||
gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream, api_name='textgen'))
|
||||
gen_events.append(shared.gradio['textbox'].submit(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream))
|
||||
shared.gradio['Stop'].click(stop_everything_event, [], [], queue=False, cancels=gen_events if shared.args.no_stream else None)
|
||||
shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js}}}")
|
||||
@@ -512,17 +475,14 @@ def create_interface():
|
||||
with gr.Tab("Parameters", elem_id="parameters"):
|
||||
create_settings_menus(default_preset)
|
||||
|
||||
shared.input_params = [shared.gradio[k] for k in ['textbox', 'generate_state']]
|
||||
shared.input_params = [shared.gradio[k] for k in ['textbox', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'seed']]
|
||||
output_params = [shared.gradio[k] for k in ['output_textbox', 'markdown', 'html']]
|
||||
gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream))
|
||||
gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream, api_name='textgen'))
|
||||
gen_events.append(shared.gradio['textbox'].submit(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream))
|
||||
gen_events.append(shared.gradio['Continue'].click(generate_reply, [shared.gradio['output_textbox']] + shared.input_params[1:], output_params, show_progress=shared.args.no_stream))
|
||||
shared.gradio['Stop'].click(stop_everything_event, [], [], queue=False, cancels=gen_events if shared.args.no_stream else None)
|
||||
shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js}}}")
|
||||
|
||||
with gr.Tab("Model", elem_id="model-tab"):
|
||||
create_model_menus()
|
||||
|
||||
with gr.Tab("Training", elem_id="training-tab"):
|
||||
training.create_train_interface()
|
||||
|
||||
@@ -536,6 +496,7 @@ def create_interface():
|
||||
cmd_list = vars(shared.args)
|
||||
bool_list = [k for k in cmd_list if type(cmd_list[k]) is bool and k not in modes]
|
||||
bool_active = [k for k in bool_list if vars(shared.args)[k]]
|
||||
#int_list = [k for k in cmd_list if type(k) is int]
|
||||
|
||||
gr.Markdown("*Experimental*")
|
||||
shared.gradio['interface_modes_menu'] = gr.Dropdown(choices=modes, value=current_mode, label="Mode")
|
||||
@@ -544,26 +505,11 @@ def create_interface():
|
||||
shared.gradio['reset_interface'] = gr.Button("Apply and restart the interface", type="primary")
|
||||
|
||||
shared.gradio['reset_interface'].click(set_interface_arguments, [shared.gradio[k] for k in ['interface_modes_menu', 'extensions_menu', 'bool_menu']], None)
|
||||
shared.gradio['reset_interface'].click(lambda: None, None, None, _js='() => {document.body.innerHTML=\'<h1 style="font-family:monospace;margin-top:20%;color:lightgray;text-align:center;">Reloading...</h1>\'; setTimeout(function(){location.reload()},2500); return []}')
|
||||
shared.gradio['reset_interface'].click(lambda : None, None, None, _js='() => {document.body.innerHTML=\'<h1 style="font-family:monospace;margin-top:20%;color:lightgray;text-align:center;">Reloading...</h1>\'; setTimeout(function(){location.reload()},2500); return []}')
|
||||
|
||||
if shared.args.extensions is not None:
|
||||
extensions_module.create_extensions_block()
|
||||
|
||||
def change_dict_value(d, key, value):
|
||||
d[key] = value
|
||||
return d
|
||||
|
||||
for k in ['do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'max_new_tokens', 'seed', 'stop_at_newline', 'chat_prompt_size_slider', 'chat_generation_attempts']:
|
||||
if k not in shared.gradio:
|
||||
continue
|
||||
if type(shared.gradio[k]) in [gr.Checkbox, gr.Number]:
|
||||
shared.gradio[k].change(lambda state, value, copy=k: change_dict_value(state, copy, value), inputs=[shared.gradio['generate_state'], shared.gradio[k]], outputs=shared.gradio['generate_state'])
|
||||
else:
|
||||
shared.gradio[k].release(lambda state, value, copy=k: change_dict_value(state, copy, value), inputs=[shared.gradio['generate_state'], shared.gradio[k]], outputs=shared.gradio['generate_state'])
|
||||
|
||||
if not shared.is_chat():
|
||||
api.create_apis()
|
||||
|
||||
# Authentication
|
||||
auth = None
|
||||
if shared.args.gradio_auth_path is not None:
|
||||
@@ -580,7 +526,6 @@ def create_interface():
|
||||
else:
|
||||
shared.gradio['interface'].launch(prevent_thread_lock=True, share=shared.args.share, server_port=shared.args.listen_port, inbrowser=shared.args.auto_launch, auth=auth)
|
||||
|
||||
|
||||
create_interface()
|
||||
|
||||
while True:
|
||||
|
||||
Reference in New Issue
Block a user