Compare commits
116 Commits
state_as_f
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
49ce866c99 | ||
|
|
ff610b47d2 | ||
|
|
3850f13624 | ||
|
|
461ca7faf5 | ||
|
|
832ee4323d | ||
|
|
1405cd8af2 | ||
|
|
2289d3686f | ||
|
|
61641a4551 | ||
|
|
f2be87235d | ||
|
|
8265d45db8 | ||
|
|
37d52c96bc | ||
|
|
f2ec880e81 | ||
|
|
f34f2daa3d | ||
|
|
cacbcda208 | ||
|
|
749c08a4ff | ||
|
|
e9e93189ff | ||
|
|
dc3c9d00a0 | ||
|
|
457d3c58eb | ||
|
|
78bbc66fc4 | ||
|
|
0f212093a3 | ||
|
|
64f5c90ee7 | ||
|
|
58b34c0841 | ||
|
|
5234071c04 | ||
|
|
09d8119e3c | ||
|
|
0caf718a21 | ||
|
|
85a7954823 | ||
|
|
d37b4f76b1 | ||
|
|
bd04ff27ad | ||
|
|
f035b01823 | ||
|
|
b7ca89ba3f | ||
|
|
52339e9b20 | ||
|
|
4961f43702 | ||
|
|
617530296e | ||
|
|
0f1627eff1 | ||
|
|
d679c4be13 | ||
|
|
45244ed125 | ||
|
|
7e70741a4e | ||
|
|
11b23db8d4 | ||
|
|
2c14df81a8 | ||
|
|
c6e9ba20a4 | ||
|
|
843f672227 | ||
|
|
769aa900ea | ||
|
|
32d078487e | ||
|
|
30befe492a | ||
|
|
1911504f82 | ||
|
|
8178fde2cb | ||
|
|
dba2000d2b | ||
|
|
65552d2157 | ||
|
|
8c6155251a | ||
|
|
992663fa20 | ||
|
|
625d81f495 | ||
|
|
57f768eaad | ||
|
|
a3085dba07 | ||
|
|
120f5662cf | ||
|
|
b27d757fd1 | ||
|
|
d29f4624e9 | ||
|
|
170e0c05c4 | ||
|
|
34ec02d41d | ||
|
|
f91d3a3ff4 | ||
|
|
ebdf4c8c12 | ||
|
|
7436dd5b4a | ||
|
|
bce1b7fbb2 | ||
|
|
f7860ce192 | ||
|
|
ece8ed2c84 | ||
|
|
cc693a7546 | ||
|
|
2fde50a800 | ||
|
|
acc235aced | ||
|
|
df561fd896 | ||
|
|
d272ac46dd | ||
|
|
cb169d0834 | ||
|
|
2f16d0afca | ||
|
|
a6a00cb82f | ||
|
|
c97c270040 | ||
|
|
0b458bf82d | ||
|
|
ffd102e5c0 | ||
|
|
5543a5089d | ||
|
|
1dc464dcb0 | ||
|
|
962e33dc10 | ||
|
|
42ea6a3fc0 | ||
|
|
e563b015d8 | ||
|
|
1c413ed593 | ||
|
|
3f922d4bfb | ||
|
|
744bf7cbf2 | ||
|
|
768354239b | ||
|
|
6762e62a40 | ||
|
|
a453d4e9c4 | ||
|
|
ec979cd9c4 | ||
|
|
2c0018d946 | ||
|
|
8fa182cfa7 | ||
|
|
862aad637b | ||
|
|
46c4654226 | ||
|
|
ea6e77df72 | ||
|
|
848c4edfd5 | ||
|
|
e047cd1def | ||
|
|
08b9d1b23a | ||
|
|
64bcde56ab | ||
|
|
58ed87e5d9 | ||
|
|
21be80242e | ||
|
|
310bf46a94 | ||
|
|
20b8ca4482 | ||
|
|
113f94b61e | ||
|
|
5f4f38ca5d | ||
|
|
d9e7aba714 | ||
|
|
59058576b5 | ||
|
|
eec3665845 | ||
|
|
03cb44fc8c | ||
|
|
39f3fec913 | ||
|
|
8cd899515e | ||
|
|
4a28f39823 | ||
|
|
158ec51ae3 | ||
|
|
0c7ef26981 | ||
|
|
5b301d9a02 | ||
|
|
4a400320dd | ||
|
|
e94ab5dac1 | ||
|
|
641646a801 | ||
|
|
3f3e42e26c |
9
.dockerignore
Normal file
9
.dockerignore
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
.env
|
||||||
|
Dockerfile
|
||||||
|
/characters
|
||||||
|
/loras
|
||||||
|
/models
|
||||||
|
/presets
|
||||||
|
/prompts
|
||||||
|
/softprompts
|
||||||
|
/training
|
||||||
25
.env.example
Normal file
25
.env.example
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
# by default the Dockerfile specifies these versions: 3.5;5.0;6.0;6.1;7.0;7.5;8.0;8.6+PTX
|
||||||
|
# however for me to work i had to specify the exact version for my card ( 2060 ) it was 7.5
|
||||||
|
# https://developer.nvidia.com/cuda-gpus you can find the version for your card here
|
||||||
|
TORCH_CUDA_ARCH_LIST=7.5
|
||||||
|
|
||||||
|
# these commands worked for me with roughly 4.5GB of vram
|
||||||
|
CLI_ARGS=--model llama-7b-4bit --wbits 4 --listen --auto-devices
|
||||||
|
|
||||||
|
# the following examples have been tested with the files linked in docs/README_docker.md:
|
||||||
|
# example running 13b with 4bit/128 groupsize : CLI_ARGS=--model llama-13b-4bit-128g --wbits 4 --listen --groupsize 128 --pre_layer 25
|
||||||
|
# example with loading api extension and public share: CLI_ARGS=--model llama-7b-4bit --wbits 4 --listen --auto-devices --no-stream --extensions api --share
|
||||||
|
# example running 7b with 8bit groupsize : CLI_ARGS=--model llama-7b --load-in-8bit --listen --auto-devices
|
||||||
|
|
||||||
|
# the port the webui binds to on the host
|
||||||
|
HOST_PORT=7860
|
||||||
|
# the port the webui binds to inside the container
|
||||||
|
CONTAINER_PORT=7860
|
||||||
|
|
||||||
|
# the port the api binds to on the host
|
||||||
|
HOST_API_PORT=5000
|
||||||
|
# the port the api binds to inside the container
|
||||||
|
CONTAINER_API_PORT=5000
|
||||||
|
|
||||||
|
# the version used to install text-generation-webui from
|
||||||
|
WEBUI_VERSION=HEAD
|
||||||
68
Dockerfile
Normal file
68
Dockerfile
Normal file
@@ -0,0 +1,68 @@
|
|||||||
|
FROM nvidia/cuda:11.8.0-devel-ubuntu22.04 as builder
|
||||||
|
|
||||||
|
RUN apt-get update && \
|
||||||
|
apt-get install --no-install-recommends -y git vim build-essential python3-dev python3-venv && \
|
||||||
|
rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
|
RUN git clone https://github.com/oobabooga/GPTQ-for-LLaMa /build
|
||||||
|
|
||||||
|
WORKDIR /build
|
||||||
|
|
||||||
|
RUN python3 -m venv /build/venv
|
||||||
|
RUN . /build/venv/bin/activate && \
|
||||||
|
pip3 install --upgrade pip setuptools && \
|
||||||
|
pip3 install torch torchvision torchaudio && \
|
||||||
|
pip3 install -r requirements.txt
|
||||||
|
|
||||||
|
# https://developer.nvidia.com/cuda-gpus
|
||||||
|
# for a rtx 2060: ARG TORCH_CUDA_ARCH_LIST="7.5"
|
||||||
|
ARG TORCH_CUDA_ARCH_LIST="3.5;5.0;6.0;6.1;7.0;7.5;8.0;8.6+PTX"
|
||||||
|
RUN . /build/venv/bin/activate && \
|
||||||
|
python3 setup_cuda.py bdist_wheel -d .
|
||||||
|
|
||||||
|
FROM nvidia/cuda:11.8.0-runtime-ubuntu22.04
|
||||||
|
|
||||||
|
LABEL maintainer="Your Name <your.email@example.com>"
|
||||||
|
LABEL description="Docker image for GPTQ-for-LLaMa and Text Generation WebUI"
|
||||||
|
|
||||||
|
RUN apt-get update && \
|
||||||
|
apt-get install --no-install-recommends -y git python3 python3-pip make g++ && \
|
||||||
|
rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
|
RUN --mount=type=cache,target=/root/.cache/pip pip3 install virtualenv
|
||||||
|
RUN mkdir /app
|
||||||
|
|
||||||
|
WORKDIR /app
|
||||||
|
|
||||||
|
ARG WEBUI_VERSION
|
||||||
|
RUN test -n "${WEBUI_VERSION}" && git reset --hard ${WEBUI_VERSION} || echo "Using provided webui source"
|
||||||
|
|
||||||
|
RUN virtualenv /app/venv
|
||||||
|
RUN . /app/venv/bin/activate && \
|
||||||
|
pip3 install --upgrade pip setuptools && \
|
||||||
|
pip3 install torch torchvision torchaudio
|
||||||
|
|
||||||
|
COPY --from=builder /build /app/repositories/GPTQ-for-LLaMa
|
||||||
|
RUN . /app/venv/bin/activate && \
|
||||||
|
pip3 install /app/repositories/GPTQ-for-LLaMa/*.whl
|
||||||
|
|
||||||
|
COPY extensions/api/requirements.txt /app/extensions/api/requirements.txt
|
||||||
|
COPY extensions/elevenlabs_tts/requirements.txt /app/extensions/elevenlabs_tts/requirements.txt
|
||||||
|
COPY extensions/google_translate/requirements.txt /app/extensions/google_translate/requirements.txt
|
||||||
|
COPY extensions/silero_tts/requirements.txt /app/extensions/silero_tts/requirements.txt
|
||||||
|
COPY extensions/whisper_stt/requirements.txt /app/extensions/whisper_stt/requirements.txt
|
||||||
|
RUN --mount=type=cache,target=/root/.cache/pip . /app/venv/bin/activate && cd extensions/api && pip3 install -r requirements.txt
|
||||||
|
RUN --mount=type=cache,target=/root/.cache/pip . /app/venv/bin/activate && cd extensions/elevenlabs_tts && pip3 install -r requirements.txt
|
||||||
|
RUN --mount=type=cache,target=/root/.cache/pip . /app/venv/bin/activate && cd extensions/google_translate && pip3 install -r requirements.txt
|
||||||
|
RUN --mount=type=cache,target=/root/.cache/pip . /app/venv/bin/activate && cd extensions/silero_tts && pip3 install -r requirements.txt
|
||||||
|
RUN --mount=type=cache,target=/root/.cache/pip . /app/venv/bin/activate && cd extensions/whisper_stt && pip3 install -r requirements.txt
|
||||||
|
|
||||||
|
COPY requirements.txt /app/requirements.txt
|
||||||
|
RUN . /app/venv/bin/activate && \
|
||||||
|
pip3 install -r requirements.txt
|
||||||
|
|
||||||
|
RUN cp /app/venv/lib/python3.10/site-packages/bitsandbytes/libbitsandbytes_cuda118.so /app/venv/lib/python3.10/site-packages/bitsandbytes/libbitsandbytes_cpu.so
|
||||||
|
|
||||||
|
COPY . /app/
|
||||||
|
ENV CLI_ARGS=""
|
||||||
|
CMD . /app/venv/bin/activate && python3 server.py ${CLI_ARGS}
|
||||||
62
README.md
62
README.md
@@ -1,11 +1,9 @@
|
|||||||
# Text generation web UI
|
# Text generation web UI
|
||||||
|
|
||||||
A gradio web UI for running Large Language Models like LLaMA, llama.cpp, GPT-J, OPT, and GALACTICA.
|
A gradio web UI for running Large Language Models like LLaMA, llama.cpp, GPT-J, Pythia, OPT, and GALACTICA.
|
||||||
|
|
||||||
Its goal is to become the [AUTOMATIC1111/stable-diffusion-webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui) of text generation.
|
Its goal is to become the [AUTOMATIC1111/stable-diffusion-webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui) of text generation.
|
||||||
|
|
||||||
[[Try it on Google Colab]](https://colab.research.google.com/github/oobabooga/AI-Notebooks/blob/main/Colab-TextGen-GPU.ipynb)
|
|
||||||
|
|
||||||
| |  |
|
| |  |
|
||||||
|:---:|:---:|
|
|:---:|:---:|
|
||||||
| |  |
|
| |  |
|
||||||
@@ -15,6 +13,7 @@ Its goal is to become the [AUTOMATIC1111/stable-diffusion-webui](https://github.
|
|||||||
* Dropdown menu for switching between models
|
* Dropdown menu for switching between models
|
||||||
* Notebook mode that resembles OpenAI's playground
|
* Notebook mode that resembles OpenAI's playground
|
||||||
* Chat mode for conversation and role playing
|
* Chat mode for conversation and role playing
|
||||||
|
* Instruct mode compatible with Alpaca, Vicuna, and Open Assistant formats **\*NEW!\***
|
||||||
* Nice HTML output for GPT-4chan
|
* Nice HTML output for GPT-4chan
|
||||||
* Markdown output for [GALACTICA](https://github.com/paperswithcode/galai), including LaTeX rendering
|
* Markdown output for [GALACTICA](https://github.com/paperswithcode/galai), including LaTeX rendering
|
||||||
* [Custom chat characters](https://github.com/oobabooga/text-generation-webui/wiki/Custom-chat-characters)
|
* [Custom chat characters](https://github.com/oobabooga/text-generation-webui/wiki/Custom-chat-characters)
|
||||||
@@ -30,10 +29,9 @@ Its goal is to become the [AUTOMATIC1111/stable-diffusion-webui](https://github.
|
|||||||
* [LLaMA model, including 4-bit GPTQ](https://github.com/oobabooga/text-generation-webui/wiki/LLaMA-model)
|
* [LLaMA model, including 4-bit GPTQ](https://github.com/oobabooga/text-generation-webui/wiki/LLaMA-model)
|
||||||
* [llama.cpp](https://github.com/oobabooga/text-generation-webui/wiki/llama.cpp-models) **\*NEW!\***
|
* [llama.cpp](https://github.com/oobabooga/text-generation-webui/wiki/llama.cpp-models) **\*NEW!\***
|
||||||
* [RWKV model](https://github.com/oobabooga/text-generation-webui/wiki/RWKV-model)
|
* [RWKV model](https://github.com/oobabooga/text-generation-webui/wiki/RWKV-model)
|
||||||
* [LoRa (loading and training)](https://github.com/oobabooga/text-generation-webui/wiki/Using-LoRAs)
|
* [LoRA (loading and training)](https://github.com/oobabooga/text-generation-webui/wiki/Using-LoRAs)
|
||||||
* Softprompts
|
* Softprompts
|
||||||
* [Extensions](https://github.com/oobabooga/text-generation-webui/wiki/Extensions)
|
* [Extensions](https://github.com/oobabooga/text-generation-webui/wiki/Extensions)
|
||||||
* [Google Colab](https://github.com/oobabooga/text-generation-webui/wiki/Running-on-Colab)
|
|
||||||
|
|
||||||
## Installation
|
## Installation
|
||||||
|
|
||||||
@@ -72,9 +70,15 @@ On Linux or WSL, it can be automatically installed with these two commands:
|
|||||||
curl -sL "https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh" > "Miniconda3.sh"
|
curl -sL "https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh" > "Miniconda3.sh"
|
||||||
bash Miniconda3.sh
|
bash Miniconda3.sh
|
||||||
```
|
```
|
||||||
|
|
||||||
Source: https://educe-ubc.github.io/conda.html
|
Source: https://educe-ubc.github.io/conda.html
|
||||||
|
|
||||||
|
#### 0.1 (Ubuntu/WSL) Install build tools
|
||||||
|
|
||||||
|
```
|
||||||
|
sudo apt install build-essential
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
#### 1. Create a new conda environment
|
#### 1. Create a new conda environment
|
||||||
|
|
||||||
```
|
```
|
||||||
@@ -116,8 +120,26 @@ As an alternative to the recommended WSL method, you can install the web UI nati
|
|||||||
|
|
||||||
### Alternative: Docker
|
### Alternative: Docker
|
||||||
|
|
||||||
https://github.com/oobabooga/text-generation-webui/issues/174, https://github.com/oobabooga/text-generation-webui/issues/87
|
```
|
||||||
|
cp .env.example .env
|
||||||
|
docker compose up --build
|
||||||
|
```
|
||||||
|
|
||||||
|
Make sure to edit `.env.example` and set the appropriate CUDA version for your GPU.
|
||||||
|
|
||||||
|
You need to have docker compose v2.17 or higher installed in your system. For installation instructions, see [Docker compose installation](https://github.com/oobabooga/text-generation-webui/wiki/Docker-compose-installation).
|
||||||
|
|
||||||
|
Contributed by [@loeken](https://github.com/loeken) in [#633](https://github.com/oobabooga/text-generation-webui/pull/633)
|
||||||
|
|
||||||
|
### Updating the requirements
|
||||||
|
|
||||||
|
From time to time, the `requirements.txt` changes. To update, use this command:
|
||||||
|
|
||||||
|
```
|
||||||
|
conda activate textgen
|
||||||
|
cd text-generation-webui
|
||||||
|
pip install -r requirements.txt --upgrade
|
||||||
|
```
|
||||||
## Downloading models
|
## Downloading models
|
||||||
|
|
||||||
Models should be placed inside the `models` folder.
|
Models should be placed inside the `models` folder.
|
||||||
@@ -173,14 +195,14 @@ Optionally, you can use the following command-line flags:
|
|||||||
#### Basic settings
|
#### Basic settings
|
||||||
|
|
||||||
| Flag | Description |
|
| Flag | Description |
|
||||||
|------------------|-------------|
|
|--------------------------------------------|-------------|
|
||||||
| `-h`, `--help` | show this help message and exit |
|
| `-h`, `--help` | Show this help message and exit. |
|
||||||
| `--notebook` | Launch the web UI in notebook mode, where the output is written to the same text box as the input. |
|
| `--notebook` | Launch the web UI in notebook mode, where the output is written to the same text box as the input. |
|
||||||
| `--chat` | Launch the web UI in chat mode. |
|
| `--chat` | Launch the web UI in chat mode. |
|
||||||
| `--model MODEL` | Name of the model to load by default. |
|
| `--model MODEL` | Name of the model to load by default. |
|
||||||
| `--lora LORA` | Name of the LoRA to apply to the model by default. |
|
| `--lora LORA` | Name of the LoRA to apply to the model by default. |
|
||||||
| `--model-dir MODEL_DIR` | Path to directory with all the models |
|
| `--model-dir MODEL_DIR` | Path to directory with all the models. |
|
||||||
| `--lora-dir LORA_DIR` | Path to directory with all the loras |
|
| `--lora-dir LORA_DIR` | Path to directory with all the loras. |
|
||||||
| `--no-stream` | Don't stream the text output in real time. |
|
| `--no-stream` | Don't stream the text output in real time. |
|
||||||
| `--settings SETTINGS_FILE` | Load the default interface settings from this json file. See `settings-template.json` for an example. If you create a file called `settings.json`, this file will be loaded by default without the need to use the `--settings` flag. |
|
| `--settings SETTINGS_FILE` | Load the default interface settings from this json file. See `settings-template.json` for an example. If you create a file called `settings.json`, this file will be loaded by default without the need to use the `--settings` flag. |
|
||||||
| `--extensions EXTENSIONS [EXTENSIONS ...]` | The list of extensions to load. If you want to load more than one extension, write the names separated by spaces. |
|
| `--extensions EXTENSIONS [EXTENSIONS ...]` | The list of extensions to load. If you want to load more than one extension, write the names separated by spaces. |
|
||||||
@@ -189,8 +211,8 @@ Optionally, you can use the following command-line flags:
|
|||||||
#### Accelerate/transformers
|
#### Accelerate/transformers
|
||||||
|
|
||||||
| Flag | Description |
|
| Flag | Description |
|
||||||
|------------------|-------------|
|
|---------------------------------------------|-------------|
|
||||||
| `--cpu` | Use the CPU to generate text.|
|
| `--cpu` | Use the CPU to generate text. Warning: Training on CPU is extremely slow.|
|
||||||
| `--auto-devices` | Automatically split the model across the available GPU(s) and CPU. |
|
| `--auto-devices` | Automatically split the model across the available GPU(s) and CPU. |
|
||||||
| `--gpu-memory GPU_MEMORY [GPU_MEMORY ...]` | Maxmimum GPU memory in GiB to be allocated per GPU. Example: `--gpu-memory 10` for a single GPU, `--gpu-memory 10 5` for two GPUs. You can also set values in MiB like `--gpu-memory 3500MiB`. |
|
| `--gpu-memory GPU_MEMORY [GPU_MEMORY ...]` | Maxmimum GPU memory in GiB to be allocated per GPU. Example: `--gpu-memory 10` for a single GPU, `--gpu-memory 10 5` for two GPUs. You can also set values in MiB like `--gpu-memory 3500MiB`. |
|
||||||
| `--cpu-memory CPU_MEMORY` | Maximum CPU memory in GiB to allocate for offloaded weights. Same as above.|
|
| `--cpu-memory CPU_MEMORY` | Maximum CPU memory in GiB to allocate for offloaded weights. Same as above.|
|
||||||
@@ -199,17 +221,19 @@ Optionally, you can use the following command-line flags:
|
|||||||
| `--load-in-8bit` | Load the model with 8-bit precision.|
|
| `--load-in-8bit` | Load the model with 8-bit precision.|
|
||||||
| `--bf16` | Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU. |
|
| `--bf16` | Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU. |
|
||||||
| `--no-cache` | Set `use_cache` to False while generating text. This reduces the VRAM usage a bit with a performance cost. |
|
| `--no-cache` | Set `use_cache` to False while generating text. This reduces the VRAM usage a bit with a performance cost. |
|
||||||
|
| `--xformers` | Use xformer's memory efficient attention. This should increase your tokens/s. |
|
||||||
|
| `--sdp-attention` | Use torch 2.0's sdp attention. |
|
||||||
|
|
||||||
#### llama.cpp
|
#### llama.cpp
|
||||||
|
|
||||||
| Flag | Description |
|
| Flag | Description |
|
||||||
|------------------|-------------|
|
|-------------|-------------|
|
||||||
| `--threads` | Number of threads to use in llama.cpp. |
|
| `--threads` | Number of threads to use in llama.cpp. |
|
||||||
|
|
||||||
#### GPTQ
|
#### GPTQ
|
||||||
|
|
||||||
| Flag | Description |
|
| Flag | Description |
|
||||||
|------------------|-------------|
|
|---------------------------|-------------|
|
||||||
| `--wbits WBITS` | GPTQ: Load a pre-quantized model with specified precision in bits. 2, 3, 4 and 8 are supported. |
|
| `--wbits WBITS` | GPTQ: Load a pre-quantized model with specified precision in bits. 2, 3, 4 and 8 are supported. |
|
||||||
| `--model_type MODEL_TYPE` | GPTQ: Model type of pre-quantized model. Currently LLaMA, OPT, and GPT-J are supported. |
|
| `--model_type MODEL_TYPE` | GPTQ: Model type of pre-quantized model. Currently LLaMA, OPT, and GPT-J are supported. |
|
||||||
| `--groupsize GROUPSIZE` | GPTQ: Group size. |
|
| `--groupsize GROUPSIZE` | GPTQ: Group size. |
|
||||||
@@ -227,7 +251,7 @@ Optionally, you can use the following command-line flags:
|
|||||||
#### DeepSpeed
|
#### DeepSpeed
|
||||||
|
|
||||||
| Flag | Description |
|
| Flag | Description |
|
||||||
|------------------|-------------|
|
|---------------------------------------|-------------|
|
||||||
| `--deepspeed` | Enable the use of DeepSpeed ZeRO-3 for inference via the Transformers integration. |
|
| `--deepspeed` | Enable the use of DeepSpeed ZeRO-3 for inference via the Transformers integration. |
|
||||||
| `--nvme-offload-dir NVME_OFFLOAD_DIR` | DeepSpeed: Directory to use for ZeRO-3 NVME offloading. |
|
| `--nvme-offload-dir NVME_OFFLOAD_DIR` | DeepSpeed: Directory to use for ZeRO-3 NVME offloading. |
|
||||||
| `--local_rank LOCAL_RANK` | DeepSpeed: Optional argument for distributed setups. |
|
| `--local_rank LOCAL_RANK` | DeepSpeed: Optional argument for distributed setups. |
|
||||||
@@ -235,14 +259,14 @@ Optionally, you can use the following command-line flags:
|
|||||||
#### RWKV
|
#### RWKV
|
||||||
|
|
||||||
| Flag | Description |
|
| Flag | Description |
|
||||||
|------------------|-------------|
|
|---------------------------------|-------------|
|
||||||
| `--rwkv-strategy RWKV_STRATEGY` | RWKV: The strategy to use while loading the model. Examples: "cpu fp32", "cuda fp16", "cuda fp16i8". |
|
| `--rwkv-strategy RWKV_STRATEGY` | RWKV: The strategy to use while loading the model. Examples: "cpu fp32", "cuda fp16", "cuda fp16i8". |
|
||||||
| `--rwkv-cuda-on` | RWKV: Compile the CUDA kernel for better performance. |
|
| `--rwkv-cuda-on` | RWKV: Compile the CUDA kernel for better performance. |
|
||||||
|
|
||||||
#### Gradio
|
#### Gradio
|
||||||
|
|
||||||
| Flag | Description |
|
| Flag | Description |
|
||||||
|------------------|-------------|
|
|---------------------------------------|-------------|
|
||||||
| `--listen` | Make the web UI reachable from your local network. |
|
| `--listen` | Make the web UI reachable from your local network. |
|
||||||
| `--listen-port LISTEN_PORT` | The listening port that the server will use. |
|
| `--listen-port LISTEN_PORT` | The listening port that the server will use. |
|
||||||
| `--share` | Create a public URL. This is useful for running the web UI on Google Colab or similar. |
|
| `--share` | Create a public URL. This is useful for running the web UI on Google Colab or similar. |
|
||||||
@@ -267,6 +291,8 @@ Check the [wiki](https://github.com/oobabooga/text-generation-webui/wiki/System-
|
|||||||
|
|
||||||
Pull requests, suggestions, and issue reports are welcome.
|
Pull requests, suggestions, and issue reports are welcome.
|
||||||
|
|
||||||
|
You are also welcome to review open pull requests.
|
||||||
|
|
||||||
Before reporting a bug, make sure that you have:
|
Before reporting a bug, make sure that you have:
|
||||||
|
|
||||||
1. Created a conda environment and installed the dependencies exactly as in the *Installation* section above.
|
1. Created a conda environment and installed the dependencies exactly as in the *Installation* section above.
|
||||||
|
|||||||
@@ -12,11 +12,17 @@ import string
|
|||||||
|
|
||||||
import websockets
|
import websockets
|
||||||
|
|
||||||
|
# Note, Gradio may pick a different fn value as the definition of the Gradio app changes.
|
||||||
|
# You can always launch the web UI and inspect the websocket stream using your browser's dev tools
|
||||||
|
# to determine what value Gradio expects here.
|
||||||
|
GRADIO_FN = 29
|
||||||
|
|
||||||
|
|
||||||
def random_hash():
|
def random_hash():
|
||||||
letters = string.ascii_lowercase + string.digits
|
letters = string.ascii_lowercase + string.digits
|
||||||
return ''.join(random.choice(letters) for i in range(9))
|
return ''.join(random.choice(letters) for i in range(9))
|
||||||
|
|
||||||
|
|
||||||
async def run(context):
|
async def run(context):
|
||||||
server = "127.0.0.1"
|
server = "127.0.0.1"
|
||||||
params = {
|
params = {
|
||||||
@@ -35,6 +41,10 @@ async def run(context):
|
|||||||
'length_penalty': 1,
|
'length_penalty': 1,
|
||||||
'early_stopping': False,
|
'early_stopping': False,
|
||||||
'seed': -1,
|
'seed': -1,
|
||||||
|
'add_bos_token': True,
|
||||||
|
'truncation_length': 2048,
|
||||||
|
'custom_stopping_strings': [],
|
||||||
|
'ban_eos_token': False
|
||||||
}
|
}
|
||||||
payload = json.dumps([context, params])
|
payload = json.dumps([context, params])
|
||||||
session = random_hash()
|
session = random_hash()
|
||||||
@@ -46,14 +56,14 @@ async def run(context):
|
|||||||
case "send_hash":
|
case "send_hash":
|
||||||
await websocket.send(json.dumps({
|
await websocket.send(json.dumps({
|
||||||
"session_hash": session,
|
"session_hash": session,
|
||||||
"fn_index": 12
|
"fn_index": GRADIO_FN
|
||||||
}))
|
}))
|
||||||
case "estimation":
|
case "estimation":
|
||||||
pass
|
pass
|
||||||
case "send_data":
|
case "send_data":
|
||||||
await websocket.send(json.dumps({
|
await websocket.send(json.dumps({
|
||||||
"session_hash": session,
|
"session_hash": session,
|
||||||
"fn_index": 12,
|
"fn_index": GRADIO_FN,
|
||||||
"data": [
|
"data": [
|
||||||
payload
|
payload
|
||||||
]
|
]
|
||||||
@@ -69,6 +79,7 @@ async def run(context):
|
|||||||
|
|
||||||
prompt = "What I would like to say is the following: "
|
prompt = "What I would like to say is the following: "
|
||||||
|
|
||||||
|
|
||||||
async def get_result():
|
async def get_result():
|
||||||
async for response in run(prompt):
|
async for response in run(prompt):
|
||||||
# Print intermediate steps
|
# Print intermediate steps
|
||||||
|
|||||||
@@ -22,10 +22,10 @@ server = "127.0.0.1"
|
|||||||
params = {
|
params = {
|
||||||
'max_new_tokens': 200,
|
'max_new_tokens': 200,
|
||||||
'do_sample': True,
|
'do_sample': True,
|
||||||
'temperature': 0.5,
|
'temperature': 0.72,
|
||||||
'top_p': 0.9,
|
'top_p': 0.73,
|
||||||
'typical_p': 1,
|
'typical_p': 1,
|
||||||
'repetition_penalty': 1.05,
|
'repetition_penalty': 1.1,
|
||||||
'encoder_repetition_penalty': 1.0,
|
'encoder_repetition_penalty': 1.0,
|
||||||
'top_k': 0,
|
'top_k': 0,
|
||||||
'min_length': 0,
|
'min_length': 0,
|
||||||
@@ -35,6 +35,10 @@ params = {
|
|||||||
'length_penalty': 1,
|
'length_penalty': 1,
|
||||||
'early_stopping': False,
|
'early_stopping': False,
|
||||||
'seed': -1,
|
'seed': -1,
|
||||||
|
'add_bos_token': True,
|
||||||
|
'custom_stopping_strings': [],
|
||||||
|
'truncation_length': 2048,
|
||||||
|
'ban_eos_token': False,
|
||||||
}
|
}
|
||||||
|
|
||||||
# Input prompt
|
# Input prompt
|
||||||
|
|||||||
3
characters/instruction-following/Vicuna.yaml
Normal file
3
characters/instruction-following/Vicuna.yaml
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
name: "### Assistant:"
|
||||||
|
your_name: "### Human:"
|
||||||
|
context: "Below is an instruction that describes a task. Write a response that appropriately completes the request."
|
||||||
@@ -17,6 +17,7 @@ parser = argparse.ArgumentParser(formatter_class=lambda prog: argparse.HelpForma
|
|||||||
parser.add_argument('MODEL', type=str, default=None, nargs='?', help="Path to the input model.")
|
parser.add_argument('MODEL', type=str, default=None, nargs='?', help="Path to the input model.")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
def disable_torch_init():
|
def disable_torch_init():
|
||||||
"""
|
"""
|
||||||
Disable the redundant torch default initialization to accelerate model creation.
|
Disable the redundant torch default initialization to accelerate model creation.
|
||||||
@@ -31,12 +32,14 @@ def disable_torch_init():
|
|||||||
torch_layer_norm_init_backup = torch.nn.LayerNorm.reset_parameters
|
torch_layer_norm_init_backup = torch.nn.LayerNorm.reset_parameters
|
||||||
setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
|
setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
|
||||||
|
|
||||||
|
|
||||||
def restore_torch_init():
|
def restore_torch_init():
|
||||||
"""Rollback the change made by disable_torch_init."""
|
"""Rollback the change made by disable_torch_init."""
|
||||||
import torch
|
import torch
|
||||||
setattr(torch.nn.Linear, "reset_parameters", torch_linear_init_backup)
|
setattr(torch.nn.Linear, "reset_parameters", torch_linear_init_backup)
|
||||||
setattr(torch.nn.LayerNorm, "reset_parameters", torch_layer_norm_init_backup)
|
setattr(torch.nn.LayerNorm, "reset_parameters", torch_layer_norm_init_backup)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
path = Path(args.MODEL)
|
path = Path(args.MODEL)
|
||||||
model_name = path.name
|
model_name = path.name
|
||||||
|
|||||||
@@ -36,3 +36,8 @@ div.svelte-362y77>*, div.svelte-362y77>.form>* {
|
|||||||
.wrap.svelte-6roggh.svelte-6roggh {
|
.wrap.svelte-6roggh.svelte-6roggh {
|
||||||
max-height: 92.5%;
|
max-height: 92.5%;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* This is for the microphone button in the whisper extension */
|
||||||
|
.sm.svelte-1ipelgc {
|
||||||
|
width: 100%;
|
||||||
|
}
|
||||||
|
|||||||
@@ -7,11 +7,13 @@
|
|||||||
padding-right: 20px;
|
padding-right: 20px;
|
||||||
display: flex;
|
display: flex;
|
||||||
flex-direction: column-reverse;
|
flex-direction: column-reverse;
|
||||||
|
word-break: break-word;
|
||||||
|
overflow-wrap: anywhere;
|
||||||
}
|
}
|
||||||
|
|
||||||
.message {
|
.message {
|
||||||
display: grid;
|
display: grid;
|
||||||
grid-template-columns: 60px 1fr;
|
grid-template-columns: 60px minmax(0, 1fr);
|
||||||
padding-bottom: 25px;
|
padding-bottom: 25px;
|
||||||
font-size: 15px;
|
font-size: 15px;
|
||||||
font-family: Helvetica, Arial, sans-serif;
|
font-family: Helvetica, Arial, sans-serif;
|
||||||
@@ -64,6 +66,22 @@
|
|||||||
line-height: 1.428571429 !important;
|
line-height: 1.428571429 !important;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
.message-body li {
|
||||||
|
margin-top: 0.5em !important;
|
||||||
|
margin-bottom: 0.5em !important;
|
||||||
|
}
|
||||||
|
|
||||||
|
.message-body li > p {
|
||||||
|
display: inline !important;
|
||||||
|
}
|
||||||
|
|
||||||
|
.message-body code {
|
||||||
|
overflow-x: auto;
|
||||||
|
}
|
||||||
|
.message-body :not(pre) > code {
|
||||||
|
white-space: normal !important;
|
||||||
|
}
|
||||||
|
|
||||||
.dark .message-body p em {
|
.dark .message-body p em {
|
||||||
color: rgb(138, 138, 138) !important;
|
color: rgb(138, 138, 138) !important;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -7,6 +7,8 @@
|
|||||||
padding-right: 20px;
|
padding-right: 20px;
|
||||||
display: flex;
|
display: flex;
|
||||||
flex-direction: column-reverse;
|
flex-direction: column-reverse;
|
||||||
|
word-break: break-word;
|
||||||
|
overflow-wrap: anywhere;
|
||||||
}
|
}
|
||||||
|
|
||||||
.message {
|
.message {
|
||||||
@@ -18,10 +20,6 @@
|
|||||||
line-height: 1.428571429;
|
line-height: 1.428571429;
|
||||||
}
|
}
|
||||||
|
|
||||||
.text p {
|
|
||||||
margin-top: 5px;
|
|
||||||
}
|
|
||||||
|
|
||||||
.username {
|
.username {
|
||||||
display: none;
|
display: none;
|
||||||
}
|
}
|
||||||
@@ -29,9 +27,23 @@
|
|||||||
.message-body {}
|
.message-body {}
|
||||||
|
|
||||||
.message-body p {
|
.message-body p {
|
||||||
margin-bottom: 0 !important;
|
|
||||||
font-size: 15px !important;
|
font-size: 15px !important;
|
||||||
line-height: 1.428571429 !important;
|
}
|
||||||
|
|
||||||
|
.message-body li {
|
||||||
|
margin-top: 0.5em !important;
|
||||||
|
margin-bottom: 0.5em !important;
|
||||||
|
}
|
||||||
|
|
||||||
|
.message-body li > p {
|
||||||
|
display: inline !important;
|
||||||
|
}
|
||||||
|
|
||||||
|
.message-body code {
|
||||||
|
overflow-x: auto;
|
||||||
|
}
|
||||||
|
.message-body :not(pre) > code {
|
||||||
|
white-space: normal !important;
|
||||||
}
|
}
|
||||||
|
|
||||||
.dark .message-body p em {
|
.dark .message-body p em {
|
||||||
@@ -42,15 +54,20 @@
|
|||||||
color: rgb(110, 110, 110) !important;
|
color: rgb(110, 110, 110) !important;
|
||||||
}
|
}
|
||||||
|
|
||||||
.assistant-message {
|
.gradio-container .chat .assistant-message {
|
||||||
padding: 10px;
|
padding: 15px;
|
||||||
|
border-radius: 20px;
|
||||||
|
background-color: #0000000f;
|
||||||
|
margin-top: 9px !important;
|
||||||
|
margin-bottom: 18px !important;
|
||||||
}
|
}
|
||||||
|
|
||||||
.user-message {
|
.gradio-container .chat .user-message {
|
||||||
padding: 10px;
|
padding: 15px;
|
||||||
background-color: #f1f1f1;
|
border-radius: 20px;
|
||||||
|
margin-bottom: 9px !important;
|
||||||
}
|
}
|
||||||
|
|
||||||
.dark .user-message {
|
.dark .chat .assistant-message {
|
||||||
background-color: #ffffff1a;
|
background-color: #374151;
|
||||||
}
|
}
|
||||||
12
css/main.css
12
css/main.css
@@ -41,7 +41,7 @@ ol li p, ul li p {
|
|||||||
display: inline-block;
|
display: inline-block;
|
||||||
}
|
}
|
||||||
|
|
||||||
#main, #parameters, #chat-settings, #interface-mode, #lora, #training-tab {
|
#main, #parameters, #chat-settings, #interface-mode, #lora, #training-tab, #model-tab {
|
||||||
border: 0;
|
border: 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -67,3 +67,13 @@ span.math.inline {
|
|||||||
div.svelte-15lo0d8 > *, div.svelte-15lo0d8 > .form > * {
|
div.svelte-15lo0d8 > *, div.svelte-15lo0d8 > .form > * {
|
||||||
flex-wrap: nowrap;
|
flex-wrap: nowrap;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
.header_bar {
|
||||||
|
background-color: #f7f7f7;
|
||||||
|
margin-bottom: 40px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.dark .header_bar {
|
||||||
|
border: none !important;
|
||||||
|
background-color: #8080802b;
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
document.getElementById("main").parentNode.childNodes[0].style = "border: none; background-color: #8080802b; margin-bottom: 40px";
|
document.getElementById("main").parentNode.childNodes[0].classList.add("header_bar");
|
||||||
document.getElementById("main").parentNode.style = "padding: 0; margin: 0";
|
document.getElementById("main").parentNode.style = "padding: 0; margin: 0";
|
||||||
document.getElementById("main").parentNode.parentNode.parentNode.style = "padding: 0";
|
document.getElementById("main").parentNode.parentNode.parentNode.style = "padding: 0";
|
||||||
|
|
||||||
|
|||||||
31
docker-compose.yml
Normal file
31
docker-compose.yml
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
version: "3.3"
|
||||||
|
services:
|
||||||
|
text-generation-webui:
|
||||||
|
build:
|
||||||
|
context: .
|
||||||
|
args:
|
||||||
|
# specify which cuda version your card supports: https://developer.nvidia.com/cuda-gpus
|
||||||
|
TORCH_CUDA_ARCH_LIST: ${TORCH_CUDA_ARCH_LIST}
|
||||||
|
WEBUI_VERSION: ${WEBUI_VERSION}
|
||||||
|
env_file: .env
|
||||||
|
ports:
|
||||||
|
- "${HOST_PORT}:${CONTAINER_PORT}"
|
||||||
|
- "${HOST_API_PORT}:${CONTAINER_API_PORT}"
|
||||||
|
stdin_open: true
|
||||||
|
tty: true
|
||||||
|
volumes:
|
||||||
|
- ./characters:/app/characters
|
||||||
|
- ./extensions:/app/extensions
|
||||||
|
- ./loras:/app/loras
|
||||||
|
- ./models:/app/models
|
||||||
|
- ./presets:/app/presets
|
||||||
|
- ./prompts:/app/prompts
|
||||||
|
- ./softprompts:/app/softprompts
|
||||||
|
- ./training:/app/training
|
||||||
|
deploy:
|
||||||
|
resources:
|
||||||
|
reservations:
|
||||||
|
devices:
|
||||||
|
- driver: nvidia
|
||||||
|
device_ids: ['0']
|
||||||
|
capabilities: [gpu]
|
||||||
@@ -19,47 +19,6 @@ import requests
|
|||||||
import tqdm
|
import tqdm
|
||||||
from tqdm.contrib.concurrent import thread_map
|
from tqdm.contrib.concurrent import thread_map
|
||||||
|
|
||||||
parser = argparse.ArgumentParser()
|
|
||||||
parser.add_argument('MODEL', type=str, default=None, nargs='?')
|
|
||||||
parser.add_argument('--branch', type=str, default='main', help='Name of the Git branch to download from.')
|
|
||||||
parser.add_argument('--threads', type=int, default=1, help='Number of files to download simultaneously.')
|
|
||||||
parser.add_argument('--text-only', action='store_true', help='Only download text files (txt/json).')
|
|
||||||
parser.add_argument('--output', type=str, default=None, help='The folder where the model should be saved.')
|
|
||||||
parser.add_argument('--clean', action='store_true', help='Does not resume the previous download.')
|
|
||||||
parser.add_argument('--check', action='store_true', help='Validates the checksums of model files.')
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
def get_file(url, output_folder):
|
|
||||||
filename = Path(url.rsplit('/', 1)[1])
|
|
||||||
output_path = output_folder / filename
|
|
||||||
if output_path.exists() and not args.clean:
|
|
||||||
# Check if the file has already been downloaded completely
|
|
||||||
r = requests.get(url, stream=True)
|
|
||||||
total_size = int(r.headers.get('content-length', 0))
|
|
||||||
if output_path.stat().st_size >= total_size:
|
|
||||||
return
|
|
||||||
# Otherwise, resume the download from where it left off
|
|
||||||
headers = {'Range': f'bytes={output_path.stat().st_size}-'}
|
|
||||||
mode = 'ab'
|
|
||||||
else:
|
|
||||||
headers = {}
|
|
||||||
mode = 'wb'
|
|
||||||
|
|
||||||
r = requests.get(url, stream=True, headers=headers)
|
|
||||||
with open(output_path, mode) as f:
|
|
||||||
total_size = int(r.headers.get('content-length', 0))
|
|
||||||
block_size = 1024
|
|
||||||
with tqdm.tqdm(total=total_size, unit='iB', unit_scale=True, bar_format='{l_bar}{bar}| {n_fmt:6}/{total_fmt:6} {rate_fmt:6}') as t:
|
|
||||||
for data in r.iter_content(block_size):
|
|
||||||
t.update(len(data))
|
|
||||||
f.write(data)
|
|
||||||
|
|
||||||
def sanitize_branch_name(branch_name):
|
|
||||||
pattern = re.compile(r"^[a-zA-Z0-9._-]+$")
|
|
||||||
if pattern.match(branch_name):
|
|
||||||
return branch_name
|
|
||||||
else:
|
|
||||||
raise ValueError("Invalid branch name. Only alphanumeric characters, period, underscore and dash are allowed.")
|
|
||||||
|
|
||||||
def select_model_from_default_options():
|
def select_model_from_default_options():
|
||||||
models = {
|
models = {
|
||||||
@@ -106,7 +65,21 @@ EleutherAI/pythia-1.4b-deduped
|
|||||||
|
|
||||||
return model, branch
|
return model, branch
|
||||||
|
|
||||||
def get_download_links_from_huggingface(model, branch):
|
|
||||||
|
def sanitize_model_and_branch_names(model, branch):
|
||||||
|
if model[-1] == '/':
|
||||||
|
model = model[:-1]
|
||||||
|
if branch is None:
|
||||||
|
branch = "main"
|
||||||
|
else:
|
||||||
|
pattern = re.compile(r"^[a-zA-Z0-9._-]+$")
|
||||||
|
if not pattern.match(branch):
|
||||||
|
raise ValueError("Invalid branch name. Only alphanumeric characters, period, underscore and dash are allowed.")
|
||||||
|
|
||||||
|
return model, branch
|
||||||
|
|
||||||
|
|
||||||
|
def get_download_links_from_huggingface(model, branch, text_only=False):
|
||||||
base = "https://huggingface.co"
|
base = "https://huggingface.co"
|
||||||
page = f"/api/models/{model}/tree/{branch}?cursor="
|
page = f"/api/models/{model}/tree/{branch}?cursor="
|
||||||
cursor = b""
|
cursor = b""
|
||||||
@@ -138,14 +111,14 @@ def get_download_links_from_huggingface(model, branch):
|
|||||||
is_tokenizer = re.match("tokenizer.*\.model", fname)
|
is_tokenizer = re.match("tokenizer.*\.model", fname)
|
||||||
is_text = re.match(".*\.(txt|json|py|md)", fname) or is_tokenizer
|
is_text = re.match(".*\.(txt|json|py|md)", fname) or is_tokenizer
|
||||||
|
|
||||||
if any((is_pytorch, is_safetensors, is_pt, is_tokenizer, is_text)):
|
if any((is_pytorch, is_safetensors, is_pt, is_ggml, is_tokenizer, is_text)):
|
||||||
if 'lfs' in dict[i]:
|
if 'lfs' in dict[i]:
|
||||||
sha256.append([fname, dict[i]['lfs']['oid']])
|
sha256.append([fname, dict[i]['lfs']['oid']])
|
||||||
if is_text:
|
if is_text:
|
||||||
links.append(f"https://huggingface.co/{model}/resolve/{branch}/{fname}")
|
links.append(f"https://huggingface.co/{model}/resolve/{branch}/{fname}")
|
||||||
classifications.append('text')
|
classifications.append('text')
|
||||||
continue
|
continue
|
||||||
if not args.text_only:
|
if not text_only:
|
||||||
links.append(f"https://huggingface.co/{model}/resolve/{branch}/{fname}")
|
links.append(f"https://huggingface.co/{model}/resolve/{branch}/{fname}")
|
||||||
if is_safetensors:
|
if is_safetensors:
|
||||||
has_safetensors = True
|
has_safetensors = True
|
||||||
@@ -172,40 +145,68 @@ def get_download_links_from_huggingface(model, branch):
|
|||||||
|
|
||||||
return links, sha256, is_lora
|
return links, sha256, is_lora
|
||||||
|
|
||||||
def download_files(file_list, output_folder, num_threads=8):
|
|
||||||
thread_map(lambda url: get_file(url, output_folder), file_list, max_workers=num_threads, disable=True)
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
def get_output_folder(model, branch, is_lora, base_folder=None):
|
||||||
model = args.MODEL
|
if base_folder is None:
|
||||||
branch = args.branch
|
|
||||||
if model is None:
|
|
||||||
model, branch = select_model_from_default_options()
|
|
||||||
else:
|
|
||||||
if model[-1] == '/':
|
|
||||||
model = model[:-1]
|
|
||||||
branch = args.branch
|
|
||||||
if branch is None:
|
|
||||||
branch = "main"
|
|
||||||
else:
|
|
||||||
try:
|
|
||||||
branch = sanitize_branch_name(branch)
|
|
||||||
except ValueError as err_branch:
|
|
||||||
print(f"Error: {err_branch}")
|
|
||||||
sys.exit()
|
|
||||||
|
|
||||||
links, sha256, is_lora = get_download_links_from_huggingface(model, branch)
|
|
||||||
|
|
||||||
if args.output is not None:
|
|
||||||
base_folder = args.output
|
|
||||||
else:
|
|
||||||
base_folder = 'models' if not is_lora else 'loras'
|
base_folder = 'models' if not is_lora else 'loras'
|
||||||
|
|
||||||
output_folder = f"{'_'.join(model.split('/')[-2:])}"
|
output_folder = f"{'_'.join(model.split('/')[-2:])}"
|
||||||
if branch != 'main':
|
if branch != 'main':
|
||||||
output_folder += f'_{branch}'
|
output_folder += f'_{branch}'
|
||||||
output_folder = Path(base_folder) / output_folder
|
output_folder = Path(base_folder) / output_folder
|
||||||
|
return output_folder
|
||||||
|
|
||||||
if args.check:
|
|
||||||
|
def get_single_file(url, output_folder, start_from_scratch=False):
|
||||||
|
filename = Path(url.rsplit('/', 1)[1])
|
||||||
|
output_path = output_folder / filename
|
||||||
|
if output_path.exists() and not start_from_scratch:
|
||||||
|
# Check if the file has already been downloaded completely
|
||||||
|
r = requests.get(url, stream=True)
|
||||||
|
total_size = int(r.headers.get('content-length', 0))
|
||||||
|
if output_path.stat().st_size >= total_size:
|
||||||
|
return
|
||||||
|
# Otherwise, resume the download from where it left off
|
||||||
|
headers = {'Range': f'bytes={output_path.stat().st_size}-'}
|
||||||
|
mode = 'ab'
|
||||||
|
else:
|
||||||
|
headers = {}
|
||||||
|
mode = 'wb'
|
||||||
|
|
||||||
|
r = requests.get(url, stream=True, headers=headers)
|
||||||
|
with open(output_path, mode) as f:
|
||||||
|
total_size = int(r.headers.get('content-length', 0))
|
||||||
|
block_size = 1024
|
||||||
|
with tqdm.tqdm(total=total_size, unit='iB', unit_scale=True, bar_format='{l_bar}{bar}| {n_fmt:6}/{total_fmt:6} {rate_fmt:6}') as t:
|
||||||
|
for data in r.iter_content(block_size):
|
||||||
|
t.update(len(data))
|
||||||
|
f.write(data)
|
||||||
|
|
||||||
|
|
||||||
|
def start_download_threads(file_list, output_folder, start_from_scratch=False, threads=1):
|
||||||
|
thread_map(lambda url: get_single_file(url, output_folder, start_from_scratch=start_from_scratch), file_list, max_workers=threads, disable=True)
|
||||||
|
|
||||||
|
|
||||||
|
def download_model_files(model, branch, links, sha256, output_folder, start_from_scratch=False, threads=1):
|
||||||
|
# Creating the folder and writing the metadata
|
||||||
|
if not output_folder.exists():
|
||||||
|
output_folder.mkdir()
|
||||||
|
with open(output_folder / 'huggingface-metadata.txt', 'w') as f:
|
||||||
|
f.write(f'url: https://huggingface.co/{model}\n')
|
||||||
|
f.write(f'branch: {branch}\n')
|
||||||
|
f.write(f'download date: {str(datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"))}\n')
|
||||||
|
sha256_str = ''
|
||||||
|
for i in range(len(sha256)):
|
||||||
|
sha256_str += f' {sha256[i][1]} {sha256[i][0]}\n'
|
||||||
|
if sha256_str != '':
|
||||||
|
f.write(f'sha256sum:\n{sha256_str}')
|
||||||
|
|
||||||
|
# Downloading the files
|
||||||
|
print(f"Downloading the model to {output_folder}")
|
||||||
|
start_download_threads(links, output_folder, start_from_scratch=start_from_scratch, threads=threads)
|
||||||
|
|
||||||
|
|
||||||
|
def check_model_files(model, branch, links, sha256, output_folder):
|
||||||
# Validate the checksums
|
# Validate the checksums
|
||||||
validated = True
|
validated = True
|
||||||
for i in range(len(sha256)):
|
for i in range(len(sha256)):
|
||||||
@@ -230,21 +231,40 @@ if __name__ == '__main__':
|
|||||||
else:
|
else:
|
||||||
print('[-] Invalid checksums. Rerun download-model.py with the --clean flag.')
|
print('[-] Invalid checksums. Rerun download-model.py with the --clean flag.')
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument('MODEL', type=str, default=None, nargs='?')
|
||||||
|
parser.add_argument('--branch', type=str, default='main', help='Name of the Git branch to download from.')
|
||||||
|
parser.add_argument('--threads', type=int, default=1, help='Number of files to download simultaneously.')
|
||||||
|
parser.add_argument('--text-only', action='store_true', help='Only download text files (txt/json).')
|
||||||
|
parser.add_argument('--output', type=str, default=None, help='The folder where the model should be saved.')
|
||||||
|
parser.add_argument('--clean', action='store_true', help='Does not resume the previous download.')
|
||||||
|
parser.add_argument('--check', action='store_true', help='Validates the checksums of model files.')
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
branch = args.branch
|
||||||
|
model = args.MODEL
|
||||||
|
if model is None:
|
||||||
|
model, branch = select_model_from_default_options()
|
||||||
|
|
||||||
|
# Cleaning up the model/branch names
|
||||||
|
try:
|
||||||
|
model, branch = sanitize_model_and_branch_names(model, branch)
|
||||||
|
except ValueError as err_branch:
|
||||||
|
print(f"Error: {err_branch}")
|
||||||
|
sys.exit()
|
||||||
|
|
||||||
|
# Getting the download links from Hugging Face
|
||||||
|
links, sha256, is_lora = get_download_links_from_huggingface(model, branch, text_only=args.text_only)
|
||||||
|
|
||||||
|
# Getting the output folder
|
||||||
|
output_folder = get_output_folder(model, branch, is_lora, base_folder=args.output)
|
||||||
|
|
||||||
|
if args.check:
|
||||||
|
# Check previously downloaded files
|
||||||
|
check_model_files(model, branch, links, sha256, output_folder)
|
||||||
else:
|
else:
|
||||||
|
# Download files
|
||||||
# Creating the folder and writing the metadata
|
download_model_files(model, branch, links, sha256, output_folder, threads=args.threads)
|
||||||
if not output_folder.exists():
|
|
||||||
output_folder.mkdir()
|
|
||||||
with open(output_folder / 'huggingface-metadata.txt', 'w') as f:
|
|
||||||
f.write(f'url: https://huggingface.co/{model}\n')
|
|
||||||
f.write(f'branch: {branch}\n')
|
|
||||||
f.write(f'download date: {str(datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"))}\n')
|
|
||||||
sha256_str = ''
|
|
||||||
for i in range(len(sha256)):
|
|
||||||
sha256_str += f' {sha256[i][1]} {sha256[i][0]}\n'
|
|
||||||
if sha256_str != '':
|
|
||||||
f.write(f'sha256sum:\n{sha256_str}')
|
|
||||||
|
|
||||||
# Downloading the files
|
|
||||||
print(f"Downloading the model to {output_folder}")
|
|
||||||
download_files(links, output_folder, args.threads)
|
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ params = {
|
|||||||
'port': 5000,
|
'port': 5000,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class Handler(BaseHTTPRequestHandler):
|
class Handler(BaseHTTPRequestHandler):
|
||||||
def do_GET(self):
|
def do_GET(self):
|
||||||
if self.path == '/api/v1/model':
|
if self.path == '/api/v1/model':
|
||||||
@@ -32,7 +33,7 @@ class Handler(BaseHTTPRequestHandler):
|
|||||||
self.end_headers()
|
self.end_headers()
|
||||||
|
|
||||||
prompt = body['prompt']
|
prompt = body['prompt']
|
||||||
prompt_lines = [l.strip() for l in prompt.split('\n')]
|
prompt_lines = [k.strip() for k in prompt.split('\n')]
|
||||||
|
|
||||||
max_context = body.get('max_context_length', 2048)
|
max_context = body.get('max_context_length', 2048)
|
||||||
|
|
||||||
@@ -56,12 +57,15 @@ class Handler(BaseHTTPRequestHandler):
|
|||||||
'length_penalty': float(body.get('length_penalty', 1)),
|
'length_penalty': float(body.get('length_penalty', 1)),
|
||||||
'early_stopping': bool(body.get('early_stopping', False)),
|
'early_stopping': bool(body.get('early_stopping', False)),
|
||||||
'seed': int(body.get('seed', -1)),
|
'seed': int(body.get('seed', -1)),
|
||||||
|
'add_bos_token': int(body.get('add_bos_token', True)),
|
||||||
|
'custom_stopping_strings': body.get('custom_stopping_strings', []),
|
||||||
|
'truncation_length': int(body.get('truncation_length', 2048)),
|
||||||
|
'ban_eos_token': bool(body.get('ban_eos_token', False)),
|
||||||
}
|
}
|
||||||
|
|
||||||
generator = generate_reply(
|
generator = generate_reply(
|
||||||
prompt,
|
prompt,
|
||||||
generate_params,
|
generate_params,
|
||||||
stopping_strings=body.get('stopping_strings', []),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
answer = ''
|
answer = ''
|
||||||
@@ -77,6 +81,19 @@ class Handler(BaseHTTPRequestHandler):
|
|||||||
}]
|
}]
|
||||||
})
|
})
|
||||||
self.wfile.write(response.encode('utf-8'))
|
self.wfile.write(response.encode('utf-8'))
|
||||||
|
elif self.path == '/api/v1/token-count':
|
||||||
|
# Not compatible with KoboldAI api
|
||||||
|
self.send_response(200)
|
||||||
|
self.send_header('Content-Type', 'application/json')
|
||||||
|
self.end_headers()
|
||||||
|
|
||||||
|
tokens = encode(body['prompt'])[0]
|
||||||
|
response = json.dumps({
|
||||||
|
'results': [{
|
||||||
|
'tokens': len(tokens)
|
||||||
|
}]
|
||||||
|
})
|
||||||
|
self.wfile.write(response.encode('utf-8'))
|
||||||
else:
|
else:
|
||||||
self.send_error(404)
|
self.send_error(404)
|
||||||
|
|
||||||
@@ -95,5 +112,6 @@ def run_server():
|
|||||||
print(f'Starting KoboldAI compatible api at http://{server_addr[0]}:{server_addr[1]}/api')
|
print(f'Starting KoboldAI compatible api at http://{server_addr[0]}:{server_addr[1]}/api')
|
||||||
server.serve_forever()
|
server.serve_forever()
|
||||||
|
|
||||||
|
|
||||||
def setup():
|
def setup():
|
||||||
Thread(target=run_server, daemon=True).start()
|
Thread(target=run_server, daemon=True).start()
|
||||||
|
|||||||
@@ -1,42 +1,82 @@
|
|||||||
import gradio as gr
|
import gradio as gr
|
||||||
|
import os
|
||||||
|
|
||||||
|
# get the current directory of the script
|
||||||
|
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
|
||||||
|
# check if the bias_options.txt file exists, if not, create it
|
||||||
|
bias_file = os.path.join(current_dir, "bias_options.txt")
|
||||||
|
if not os.path.isfile(bias_file):
|
||||||
|
with open(bias_file, "w") as f:
|
||||||
|
f.write("*I am so happy*\n*I am so sad*\n*I am so excited*\n*I am so bored*\n*I am so angry*")
|
||||||
|
|
||||||
|
# read bias options from the text file
|
||||||
|
with open(bias_file, "r") as f:
|
||||||
|
bias_options = [line.strip() for line in f.readlines()]
|
||||||
|
|
||||||
params = {
|
params = {
|
||||||
"activate": True,
|
"activate": True,
|
||||||
"bias string": " *I am so happy*",
|
"bias string": " *I am so happy*",
|
||||||
|
"use custom string": False,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def input_modifier(string):
|
def input_modifier(string):
|
||||||
"""
|
"""
|
||||||
This function is applied to your text inputs before
|
This function is applied to your text inputs before
|
||||||
they are fed into the model.
|
they are fed into the model.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
return string
|
return string
|
||||||
|
|
||||||
|
|
||||||
def output_modifier(string):
|
def output_modifier(string):
|
||||||
"""
|
"""
|
||||||
This function is applied to the model outputs.
|
This function is applied to the model outputs.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
return string
|
return string
|
||||||
|
|
||||||
|
|
||||||
def bot_prefix_modifier(string):
|
def bot_prefix_modifier(string):
|
||||||
"""
|
"""
|
||||||
This function is only applied in chat mode. It modifies
|
This function is only applied in chat mode. It modifies
|
||||||
the prefix text for the Bot and can be used to bias its
|
the prefix text for the Bot and can be used to bias its
|
||||||
behavior.
|
behavior.
|
||||||
"""
|
"""
|
||||||
|
if params['activate']:
|
||||||
if params['activate'] == True:
|
if params['use custom string']:
|
||||||
|
return f'{string} {params["custom string"].strip()} '
|
||||||
|
else:
|
||||||
return f'{string} {params["bias string"].strip()} '
|
return f'{string} {params["bias string"].strip()} '
|
||||||
else:
|
else:
|
||||||
return string
|
return string
|
||||||
|
|
||||||
|
|
||||||
def ui():
|
def ui():
|
||||||
# Gradio elements
|
# Gradio elements
|
||||||
activate = gr.Checkbox(value=params['activate'], label='Activate character bias')
|
activate = gr.Checkbox(value=params['activate'], label='Activate character bias')
|
||||||
string = gr.Textbox(value=params["bias string"], label='Character bias')
|
dropdown_string = gr.Dropdown(choices=bias_options, value=params["bias string"], label='Character bias', info='To edit the options in this dropdown edit the "bias_options.txt" file')
|
||||||
|
use_custom_string = gr.Checkbox(value=False, label='Use custom bias textbox instead of dropdown')
|
||||||
|
custom_string = gr.Textbox(value="", placeholder="Enter custom bias string", label="Custom Character Bias", info='To use this textbox activate the checkbox above')
|
||||||
|
|
||||||
# Event functions to update the parameters in the backend
|
# Event functions to update the parameters in the backend
|
||||||
string.change(lambda x: params.update({"bias string": x}), string, None)
|
def update_bias_string(x):
|
||||||
|
if x:
|
||||||
|
params.update({"bias string": x})
|
||||||
|
else:
|
||||||
|
params.update({"bias string": dropdown_string.get()})
|
||||||
|
return x
|
||||||
|
|
||||||
|
def update_custom_string(x):
|
||||||
|
params.update({"custom string": x})
|
||||||
|
|
||||||
|
dropdown_string.change(update_bias_string, dropdown_string, None)
|
||||||
|
custom_string.change(update_custom_string, custom_string, None)
|
||||||
activate.change(lambda x: params.update({"activate": x}), activate, None)
|
activate.change(lambda x: params.update({"activate": x}), activate, None)
|
||||||
|
use_custom_string.change(lambda x: params.update({"use custom string": x}), use_custom_string, None)
|
||||||
|
|
||||||
|
# Group elements together depending on the selected option
|
||||||
|
def bias_string_group():
|
||||||
|
if use_custom_string.value:
|
||||||
|
return gr.Group([use_custom_string, custom_string])
|
||||||
|
else:
|
||||||
|
return dropdown_string
|
||||||
|
|||||||
@@ -2,10 +2,11 @@ import re
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
import modules.shared as shared
|
|
||||||
from elevenlabslib import ElevenLabsUser
|
from elevenlabslib import ElevenLabsUser
|
||||||
from elevenlabslib.helpers import save_bytes_to_path
|
from elevenlabslib.helpers import save_bytes_to_path
|
||||||
|
|
||||||
|
import modules.shared as shared
|
||||||
|
|
||||||
params = {
|
params = {
|
||||||
'activate': True,
|
'activate': True,
|
||||||
'api_key': '12345',
|
'api_key': '12345',
|
||||||
@@ -22,6 +23,8 @@ if not shared.args.no_stream:
|
|||||||
raise ValueError
|
raise ValueError
|
||||||
|
|
||||||
# Check if the API is valid and refresh the UI accordingly.
|
# Check if the API is valid and refresh the UI accordingly.
|
||||||
|
|
||||||
|
|
||||||
def check_valid_api():
|
def check_valid_api():
|
||||||
|
|
||||||
global user, user_info, params
|
global user, user_info, params
|
||||||
@@ -29,7 +32,7 @@ def check_valid_api():
|
|||||||
user = ElevenLabsUser(params['api_key'])
|
user = ElevenLabsUser(params['api_key'])
|
||||||
user_info = user._get_subscription_data()
|
user_info = user._get_subscription_data()
|
||||||
print('checking api')
|
print('checking api')
|
||||||
if params['activate'] == False:
|
if not params['activate']:
|
||||||
return gr.update(value='Disconnected')
|
return gr.update(value='Disconnected')
|
||||||
elif user_info is None:
|
elif user_info is None:
|
||||||
print('Incorrect API Key')
|
print('Incorrect API Key')
|
||||||
@@ -39,6 +42,8 @@ def check_valid_api():
|
|||||||
return gr.update(value='Connected')
|
return gr.update(value='Connected')
|
||||||
|
|
||||||
# Once the API is verified, get the available voices and update the dropdown list
|
# Once the API is verified, get the available voices and update the dropdown list
|
||||||
|
|
||||||
|
|
||||||
def refresh_voices():
|
def refresh_voices():
|
||||||
|
|
||||||
global user, user_info
|
global user, user_info
|
||||||
@@ -51,11 +56,13 @@ def refresh_voices():
|
|||||||
else:
|
else:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
||||||
def remove_surrounded_chars(string):
|
def remove_surrounded_chars(string):
|
||||||
# this expression matches to 'as few symbols as possible (0 upwards) between any asterisks' OR
|
# this expression matches to 'as few symbols as possible (0 upwards) between any asterisks' OR
|
||||||
# 'as few symbols as possible (0 upwards) between an asterisk and the end of the string'
|
# 'as few symbols as possible (0 upwards) between an asterisk and the end of the string'
|
||||||
return re.sub('\*[^\*]*?(\*|$)', '', string)
|
return re.sub('\*[^\*]*?(\*|$)', '', string)
|
||||||
|
|
||||||
|
|
||||||
def input_modifier(string):
|
def input_modifier(string):
|
||||||
"""
|
"""
|
||||||
This function is applied to your text inputs before
|
This function is applied to your text inputs before
|
||||||
@@ -64,6 +71,7 @@ def input_modifier(string):
|
|||||||
|
|
||||||
return string
|
return string
|
||||||
|
|
||||||
|
|
||||||
def output_modifier(string):
|
def output_modifier(string):
|
||||||
"""
|
"""
|
||||||
This function is applied to the model outputs.
|
This function is applied to the model outputs.
|
||||||
@@ -71,9 +79,9 @@ def output_modifier(string):
|
|||||||
|
|
||||||
global params, wav_idx, user, user_info
|
global params, wav_idx, user, user_info
|
||||||
|
|
||||||
if params['activate'] == False:
|
if not params['activate']:
|
||||||
return string
|
return string
|
||||||
elif user_info == None:
|
elif user_info is None:
|
||||||
return string
|
return string
|
||||||
|
|
||||||
string = remove_surrounded_chars(string)
|
string = remove_surrounded_chars(string)
|
||||||
@@ -94,6 +102,7 @@ def output_modifier(string):
|
|||||||
wav_idx += 1
|
wav_idx += 1
|
||||||
return string
|
return string
|
||||||
|
|
||||||
|
|
||||||
def ui():
|
def ui():
|
||||||
|
|
||||||
# Gradio elements
|
# Gradio elements
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ params = {
|
|||||||
|
|
||||||
language_codes = {'Afrikaans': 'af', 'Albanian': 'sq', 'Amharic': 'am', 'Arabic': 'ar', 'Armenian': 'hy', 'Azerbaijani': 'az', 'Basque': 'eu', 'Belarusian': 'be', 'Bengali': 'bn', 'Bosnian': 'bs', 'Bulgarian': 'bg', 'Catalan': 'ca', 'Cebuano': 'ceb', 'Chinese (Simplified)': 'zh-CN', 'Chinese (Traditional)': 'zh-TW', 'Corsican': 'co', 'Croatian': 'hr', 'Czech': 'cs', 'Danish': 'da', 'Dutch': 'nl', 'English': 'en', 'Esperanto': 'eo', 'Estonian': 'et', 'Finnish': 'fi', 'French': 'fr', 'Frisian': 'fy', 'Galician': 'gl', 'Georgian': 'ka', 'German': 'de', 'Greek': 'el', 'Gujarati': 'gu', 'Haitian Creole': 'ht', 'Hausa': 'ha', 'Hawaiian': 'haw', 'Hebrew': 'iw', 'Hindi': 'hi', 'Hmong': 'hmn', 'Hungarian': 'hu', 'Icelandic': 'is', 'Igbo': 'ig', 'Indonesian': 'id', 'Irish': 'ga', 'Italian': 'it', 'Japanese': 'ja', 'Javanese': 'jw', 'Kannada': 'kn', 'Kazakh': 'kk', 'Khmer': 'km', 'Korean': 'ko', 'Kurdish': 'ku', 'Kyrgyz': 'ky', 'Lao': 'lo', 'Latin': 'la', 'Latvian': 'lv', 'Lithuanian': 'lt', 'Luxembourgish': 'lb', 'Macedonian': 'mk', 'Malagasy': 'mg', 'Malay': 'ms', 'Malayalam': 'ml', 'Maltese': 'mt', 'Maori': 'mi', 'Marathi': 'mr', 'Mongolian': 'mn', 'Myanmar (Burmese)': 'my', 'Nepali': 'ne', 'Norwegian': 'no', 'Nyanja (Chichewa)': 'ny', 'Pashto': 'ps', 'Persian': 'fa', 'Polish': 'pl', 'Portuguese (Portugal, Brazil)': 'pt', 'Punjabi': 'pa', 'Romanian': 'ro', 'Russian': 'ru', 'Samoan': 'sm', 'Scots Gaelic': 'gd', 'Serbian': 'sr', 'Sesotho': 'st', 'Shona': 'sn', 'Sindhi': 'sd', 'Sinhala (Sinhalese)': 'si', 'Slovak': 'sk', 'Slovenian': 'sl', 'Somali': 'so', 'Spanish': 'es', 'Sundanese': 'su', 'Swahili': 'sw', 'Swedish': 'sv', 'Tagalog (Filipino)': 'tl', 'Tajik': 'tg', 'Tamil': 'ta', 'Telugu': 'te', 'Thai': 'th', 'Turkish': 'tr', 'Ukrainian': 'uk', 'Urdu': 'ur', 'Uzbek': 'uz', 'Vietnamese': 'vi', 'Welsh': 'cy', 'Xhosa': 'xh', 'Yiddish': 'yi', 'Yoruba': 'yo', 'Zulu': 'zu'}
|
language_codes = {'Afrikaans': 'af', 'Albanian': 'sq', 'Amharic': 'am', 'Arabic': 'ar', 'Armenian': 'hy', 'Azerbaijani': 'az', 'Basque': 'eu', 'Belarusian': 'be', 'Bengali': 'bn', 'Bosnian': 'bs', 'Bulgarian': 'bg', 'Catalan': 'ca', 'Cebuano': 'ceb', 'Chinese (Simplified)': 'zh-CN', 'Chinese (Traditional)': 'zh-TW', 'Corsican': 'co', 'Croatian': 'hr', 'Czech': 'cs', 'Danish': 'da', 'Dutch': 'nl', 'English': 'en', 'Esperanto': 'eo', 'Estonian': 'et', 'Finnish': 'fi', 'French': 'fr', 'Frisian': 'fy', 'Galician': 'gl', 'Georgian': 'ka', 'German': 'de', 'Greek': 'el', 'Gujarati': 'gu', 'Haitian Creole': 'ht', 'Hausa': 'ha', 'Hawaiian': 'haw', 'Hebrew': 'iw', 'Hindi': 'hi', 'Hmong': 'hmn', 'Hungarian': 'hu', 'Icelandic': 'is', 'Igbo': 'ig', 'Indonesian': 'id', 'Irish': 'ga', 'Italian': 'it', 'Japanese': 'ja', 'Javanese': 'jw', 'Kannada': 'kn', 'Kazakh': 'kk', 'Khmer': 'km', 'Korean': 'ko', 'Kurdish': 'ku', 'Kyrgyz': 'ky', 'Lao': 'lo', 'Latin': 'la', 'Latvian': 'lv', 'Lithuanian': 'lt', 'Luxembourgish': 'lb', 'Macedonian': 'mk', 'Malagasy': 'mg', 'Malay': 'ms', 'Malayalam': 'ml', 'Maltese': 'mt', 'Maori': 'mi', 'Marathi': 'mr', 'Mongolian': 'mn', 'Myanmar (Burmese)': 'my', 'Nepali': 'ne', 'Norwegian': 'no', 'Nyanja (Chichewa)': 'ny', 'Pashto': 'ps', 'Persian': 'fa', 'Polish': 'pl', 'Portuguese (Portugal, Brazil)': 'pt', 'Punjabi': 'pa', 'Romanian': 'ro', 'Russian': 'ru', 'Samoan': 'sm', 'Scots Gaelic': 'gd', 'Serbian': 'sr', 'Sesotho': 'st', 'Shona': 'sn', 'Sindhi': 'sd', 'Sinhala (Sinhalese)': 'si', 'Slovak': 'sk', 'Slovenian': 'sl', 'Somali': 'so', 'Spanish': 'es', 'Sundanese': 'su', 'Swahili': 'sw', 'Swedish': 'sv', 'Tagalog (Filipino)': 'tl', 'Tajik': 'tg', 'Tamil': 'ta', 'Telugu': 'te', 'Thai': 'th', 'Turkish': 'tr', 'Ukrainian': 'uk', 'Urdu': 'ur', 'Uzbek': 'uz', 'Vietnamese': 'vi', 'Welsh': 'cy', 'Xhosa': 'xh', 'Yiddish': 'yi', 'Yoruba': 'yo', 'Zulu': 'zu'}
|
||||||
|
|
||||||
|
|
||||||
def input_modifier(string):
|
def input_modifier(string):
|
||||||
"""
|
"""
|
||||||
This function is applied to your text inputs before
|
This function is applied to your text inputs before
|
||||||
@@ -15,6 +16,7 @@ def input_modifier(string):
|
|||||||
|
|
||||||
return GoogleTranslator(source=params['language string'], target='en').translate(string)
|
return GoogleTranslator(source=params['language string'], target='en').translate(string)
|
||||||
|
|
||||||
|
|
||||||
def output_modifier(string):
|
def output_modifier(string):
|
||||||
"""
|
"""
|
||||||
This function is applied to the model outputs.
|
This function is applied to the model outputs.
|
||||||
@@ -22,6 +24,7 @@ def output_modifier(string):
|
|||||||
|
|
||||||
return GoogleTranslator(source='en', target=params['language string']).translate(string)
|
return GoogleTranslator(source='en', target=params['language string']).translate(string)
|
||||||
|
|
||||||
|
|
||||||
def bot_prefix_modifier(string):
|
def bot_prefix_modifier(string):
|
||||||
"""
|
"""
|
||||||
This function is only applied in chat mode. It modifies
|
This function is only applied in chat mode. It modifies
|
||||||
@@ -31,6 +34,7 @@ def bot_prefix_modifier(string):
|
|||||||
|
|
||||||
return string
|
return string
|
||||||
|
|
||||||
|
|
||||||
def ui():
|
def ui():
|
||||||
# Finding the language name from the language code to use as the default value
|
# Finding the language name from the language code to use as the default value
|
||||||
language_name = list(language_codes.keys())[list(language_codes.values()).index(params['language string'])]
|
language_name = list(language_codes.keys())[list(language_codes.values()).index(params['language string'])]
|
||||||
|
|||||||
@@ -1,15 +1,18 @@
|
|||||||
import gradio as gr
|
import gradio as gr
|
||||||
import modules.shared as shared
|
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
|
||||||
|
import modules.shared as shared
|
||||||
|
|
||||||
df = pd.read_csv("https://raw.githubusercontent.com/devbrones/llama-prompts/main/prompts/prompts.csv")
|
df = pd.read_csv("https://raw.githubusercontent.com/devbrones/llama-prompts/main/prompts/prompts.csv")
|
||||||
|
|
||||||
|
|
||||||
def get_prompt_by_name(name):
|
def get_prompt_by_name(name):
|
||||||
if name == 'None':
|
if name == 'None':
|
||||||
return ''
|
return ''
|
||||||
else:
|
else:
|
||||||
return df[df['Prompt name'] == name].iloc[0]['Prompt'].replace('\\n', '\n')
|
return df[df['Prompt name'] == name].iloc[0]['Prompt'].replace('\\n', '\n')
|
||||||
|
|
||||||
|
|
||||||
def ui():
|
def ui():
|
||||||
if not shared.is_chat():
|
if not shared.is_chat():
|
||||||
choices = ['None'] + list(df['Prompt name'])
|
choices = ['None'] + list(df['Prompt name'])
|
||||||
|
|||||||
78
extensions/sd_api_pictures/README.MD
Normal file
78
extensions/sd_api_pictures/README.MD
Normal file
@@ -0,0 +1,78 @@
|
|||||||
|
## Description:
|
||||||
|
TL;DR: Lets the bot answer you with a picture!
|
||||||
|
|
||||||
|
Stable Diffusion API pictures for TextGen, v.1.1.0
|
||||||
|
An extension to [oobabooga's textgen-webui](https://github.com/oobabooga/text-generation-webui) allowing you to receive pics generated by [Automatic1111's SD-WebUI API](https://github.com/AUTOMATIC1111/stable-diffusion-webui)
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary>Interface overview</summary>
|
||||||
|
|
||||||
|

|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
Load it in the `--chat` mode with `--extension sd_api_pictures` alongside `send_pictures` (it's not really required, but completes the picture, *pun intended*).
|
||||||
|
|
||||||
|
The image generation is triggered either:
|
||||||
|
- manually through the 'Force the picture response' button while in `Manual` or `Immersive/Interactive` modes OR
|
||||||
|
- automatically in `Immersive/Interactive` mode if the words `'send|main|message|me'` are followed by `'image|pic|picture|photo|snap|snapshot|selfie|meme'` in the user's prompt
|
||||||
|
- always on in Picturebook/Adventure mode (if not currently suppressed by 'Suppress the picture response')
|
||||||
|
|
||||||
|
## Prerequisites
|
||||||
|
|
||||||
|
One needs an available instance of Automatic1111's webui running with an `--api` flag. Ain't tested with a notebook / cloud hosted one but should be possible.
|
||||||
|
To run it locally in parallel on the same machine, specify custom `--listen-port` for either Auto1111's or ooba's webUIs.
|
||||||
|
|
||||||
|
## Features:
|
||||||
|
- API detection (press enter in the API box)
|
||||||
|
- VRAM management (model shuffling)
|
||||||
|
- Three different operation modes (manual, interactive, always-on)
|
||||||
|
- persistent settings via settings.json
|
||||||
|
|
||||||
|
The model input is modified only in the interactive mode; other two are unaffected. The output pic description is presented differently for Picture-book / Adventure mode.
|
||||||
|
|
||||||
|
Connection check (insert the Auto1111's address and press Enter):
|
||||||
|

|
||||||
|
|
||||||
|
### Persistents settings
|
||||||
|
|
||||||
|
Create or modify the `settings.json` in the `text-generation-webui` root directory to override the defaults
|
||||||
|
present in script.py, ex:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"sd_api_pictures-manage_VRAM": 1,
|
||||||
|
"sd_api_pictures-save_img": 1,
|
||||||
|
"sd_api_pictures-prompt_prefix": "(Masterpiece:1.1), detailed, intricate, colorful, (solo:1.1)",
|
||||||
|
"sd_api_pictures-sampler_name": "DPM++ 2M Karras"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
will automatically set the `Manage VRAM` & `Keep original images` checkboxes and change the texts in `Prompt Prefix` and `Sampler name` on load.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Demonstrations:
|
||||||
|
|
||||||
|
Those are examples of the version 1.0.0, but the core functionality is still the same
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary>Conversation 1</summary>
|
||||||
|
|
||||||
|

|
||||||
|

|
||||||
|

|
||||||
|

|
||||||
|
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary>Conversation 2</summary>
|
||||||
|
|
||||||
|

|
||||||
|

|
||||||
|

|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
@@ -1,93 +1,151 @@
|
|||||||
import base64
|
import base64
|
||||||
import io
|
import io
|
||||||
import re
|
import re
|
||||||
|
import time
|
||||||
|
from datetime import date
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
import modules.chat as chat
|
|
||||||
import modules.shared as shared
|
import modules.shared as shared
|
||||||
import requests
|
import requests
|
||||||
import torch
|
import torch
|
||||||
|
from modules.models import reload_model, unload_model
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
torch._C._jit_set_profiling_mode(False)
|
torch._C._jit_set_profiling_mode(False)
|
||||||
|
|
||||||
# parameters which can be customized in settings.json of webui
|
# parameters which can be customized in settings.json of webui
|
||||||
params = {
|
params = {
|
||||||
'enable_SD_api': False,
|
|
||||||
'address': 'http://127.0.0.1:7860',
|
'address': 'http://127.0.0.1:7860',
|
||||||
|
'mode': 0, # modes of operation: 0 (Manual only), 1 (Immersive/Interactive - looks for words to trigger), 2 (Picturebook Adventure - Always on)
|
||||||
|
'manage_VRAM': False,
|
||||||
'save_img': False,
|
'save_img': False,
|
||||||
'SD_model': 'NeverEndingDream', # not really used right now
|
'SD_model': 'NeverEndingDream', # not used right now
|
||||||
'prompt_prefix': '(Masterpiece:1.1), (solo:1.3), detailed, intricate, colorful',
|
'prompt_prefix': '(Masterpiece:1.1), detailed, intricate, colorful',
|
||||||
'negative_prompt': '(worst quality, low quality:1.3)',
|
'negative_prompt': '(worst quality, low quality:1.3)',
|
||||||
'side_length': 512,
|
'width': 512,
|
||||||
'restore_faces': False
|
'height': 512,
|
||||||
|
'restore_faces': False,
|
||||||
|
'seed': -1,
|
||||||
|
'sampler_name': 'DDIM',
|
||||||
|
'steps': 32,
|
||||||
|
'cfg_scale': 7
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def give_VRAM_priority(actor):
|
||||||
|
global shared, params
|
||||||
|
|
||||||
|
if actor == 'SD':
|
||||||
|
unload_model()
|
||||||
|
print("Requesting Auto1111 to re-load last checkpoint used...")
|
||||||
|
response = requests.post(url=f'{params["address"]}/sdapi/v1/reload-checkpoint', json='')
|
||||||
|
response.raise_for_status()
|
||||||
|
|
||||||
|
elif actor == 'LLM':
|
||||||
|
print("Requesting Auto1111 to vacate VRAM...")
|
||||||
|
response = requests.post(url=f'{params["address"]}/sdapi/v1/unload-checkpoint', json='')
|
||||||
|
response.raise_for_status()
|
||||||
|
reload_model()
|
||||||
|
|
||||||
|
elif actor == 'set':
|
||||||
|
print("VRAM mangement activated -- requesting Auto1111 to vacate VRAM...")
|
||||||
|
response = requests.post(url=f'{params["address"]}/sdapi/v1/unload-checkpoint', json='')
|
||||||
|
response.raise_for_status()
|
||||||
|
|
||||||
|
elif actor == 'reset':
|
||||||
|
print("VRAM mangement deactivated -- requesting Auto1111 to reload checkpoint")
|
||||||
|
response = requests.post(url=f'{params["address"]}/sdapi/v1/reload-checkpoint', json='')
|
||||||
|
response.raise_for_status()
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise RuntimeError(f'Managing VRAM: "{actor}" is not a known state!')
|
||||||
|
|
||||||
|
response.raise_for_status()
|
||||||
|
del response
|
||||||
|
|
||||||
|
|
||||||
|
if params['manage_VRAM']:
|
||||||
|
give_VRAM_priority('set')
|
||||||
|
|
||||||
|
samplers = ['DDIM', 'DPM++ 2M Karras'] # TODO: get the availible samplers with http://{address}}/sdapi/v1/samplers
|
||||||
SD_models = ['NeverEndingDream'] # TODO: get with http://{address}}/sdapi/v1/sd-models and allow user to select
|
SD_models = ['NeverEndingDream'] # TODO: get with http://{address}}/sdapi/v1/sd-models and allow user to select
|
||||||
|
|
||||||
streaming_state = shared.args.no_stream # remember if chat streaming was enabled
|
streaming_state = shared.args.no_stream # remember if chat streaming was enabled
|
||||||
picture_response = False # specifies if the next model response should appear as a picture
|
picture_response = False # specifies if the next model response should appear as a picture
|
||||||
pic_id = 0
|
|
||||||
|
|
||||||
def remove_surrounded_chars(string):
|
def remove_surrounded_chars(string):
|
||||||
# this expression matches to 'as few symbols as possible (0 upwards) between any asterisks' OR
|
# this expression matches to 'as few symbols as possible (0 upwards) between any asterisks' OR
|
||||||
# 'as few symbols as possible (0 upwards) between an asterisk and the end of the string'
|
# 'as few symbols as possible (0 upwards) between an asterisk and the end of the string'
|
||||||
return re.sub('\*[^\*]*?(\*|$)', '', string)
|
return re.sub('\*[^\*]*?(\*|$)', '', string)
|
||||||
|
|
||||||
# I don't even need input_hijack for this as visible text will be commited to history as the unmodified string
|
|
||||||
|
def triggers_are_in(string):
|
||||||
|
string = remove_surrounded_chars(string)
|
||||||
|
# regex searches for send|main|message|me (at the end of the word) followed by
|
||||||
|
# a whole word of image|pic|picture|photo|snap|snapshot|selfie|meme(s),
|
||||||
|
# (?aims) are regex parser flags
|
||||||
|
return bool(re.search('(?aims)(send|mail|message|me)\\b.+?\\b(image|pic(ture)?|photo|snap(shot)?|selfie|meme)s?\\b', string))
|
||||||
|
|
||||||
|
|
||||||
def input_modifier(string):
|
def input_modifier(string):
|
||||||
"""
|
"""
|
||||||
This function is applied to your text inputs before
|
This function is applied to your text inputs before
|
||||||
they are fed into the model.
|
they are fed into the model.
|
||||||
"""
|
"""
|
||||||
global params, picture_response
|
|
||||||
if not params['enable_SD_api']:
|
global params
|
||||||
|
|
||||||
|
if not params['mode'] == 1: # if not in immersive/interactive mode, do nothing
|
||||||
return string
|
return string
|
||||||
|
|
||||||
commands = ['send', 'mail', 'me']
|
if triggers_are_in(string): # if we're in it, check for trigger words
|
||||||
mediums = ['image', 'pic', 'picture', 'photo']
|
toggle_generation(True)
|
||||||
subjects = ['yourself', 'own']
|
string = string.lower()
|
||||||
lowstr = string.lower()
|
if "of" in string:
|
||||||
|
subject = string.split('of', 1)[1] # subdivide the string once by the first 'of' instance and get what's coming after it
|
||||||
# TODO: refactor out to separate handler and also replace detection with a regexp
|
string = "Please provide a detailed and vivid description of " + subject
|
||||||
if any(command in lowstr for command in commands) and any(case in lowstr for case in mediums): # trigger the generation if a command signature and a medium signature is found
|
else:
|
||||||
picture_response = True
|
string = "Please provide a detailed description of your appearance, your surroundings and what you are doing right now"
|
||||||
shared.args.no_stream = True # Disable streaming cause otherwise the SD-generated picture would return as a dud
|
|
||||||
shared.processing_message = "*Is sending a picture...*"
|
|
||||||
string = "Please provide a detailed description of your surroundings, how you look and the situation you're in and what you are doing right now"
|
|
||||||
if any(target in lowstr for target in subjects): # the focus of the image should be on the sending character
|
|
||||||
string = "Please provide a detailed and vivid description of how you look and what you are wearing"
|
|
||||||
|
|
||||||
return string
|
return string
|
||||||
|
|
||||||
# Get and save the Stable Diffusion-generated picture
|
# Get and save the Stable Diffusion-generated picture
|
||||||
def get_SD_pictures(description):
|
def get_SD_pictures(description):
|
||||||
|
|
||||||
global params, pic_id
|
global params
|
||||||
|
|
||||||
|
if params['manage_VRAM']:
|
||||||
|
give_VRAM_priority('SD')
|
||||||
|
|
||||||
payload = {
|
payload = {
|
||||||
"prompt": params['prompt_prefix'] + description,
|
"prompt": params['prompt_prefix'] + description,
|
||||||
"seed": -1,
|
"seed": params['seed'],
|
||||||
"sampler_name": "DPM++ 2M Karras",
|
"sampler_name": params['sampler_name'],
|
||||||
"steps": 32,
|
"steps": params['steps'],
|
||||||
"cfg_scale": 7,
|
"cfg_scale": params['cfg_scale'],
|
||||||
"width": params['side_length'],
|
"width": params['width'],
|
||||||
"height": params['side_length'],
|
"height": params['height'],
|
||||||
"restore_faces": params['restore_faces'],
|
"restore_faces": params['restore_faces'],
|
||||||
"negative_prompt": params['negative_prompt']
|
"negative_prompt": params['negative_prompt']
|
||||||
}
|
}
|
||||||
|
|
||||||
|
print(f'Prompting the image generator via the API on {params["address"]}...')
|
||||||
response = requests.post(url=f'{params["address"]}/sdapi/v1/txt2img', json=payload)
|
response = requests.post(url=f'{params["address"]}/sdapi/v1/txt2img', json=payload)
|
||||||
|
response.raise_for_status()
|
||||||
r = response.json()
|
r = response.json()
|
||||||
|
|
||||||
visible_result = ""
|
visible_result = ""
|
||||||
for img_str in r['images']:
|
for img_str in r['images']:
|
||||||
image = Image.open(io.BytesIO(base64.b64decode(img_str.split(",", 1)[0])))
|
image = Image.open(io.BytesIO(base64.b64decode(img_str.split(",", 1)[0])))
|
||||||
if params['save_img']:
|
if params['save_img']:
|
||||||
output_file = Path(f'extensions/sd_api_pictures/outputs/{pic_id:06d}.png')
|
variadic = f'{date.today().strftime("%Y_%m_%d")}/{shared.character}_{int(time.time())}'
|
||||||
|
output_file = Path(f'extensions/sd_api_pictures/outputs/{variadic}.png')
|
||||||
|
output_file.parent.mkdir(parents=True, exist_ok=True)
|
||||||
image.save(output_file.as_posix())
|
image.save(output_file.as_posix())
|
||||||
pic_id += 1
|
visible_result = visible_result + f'<img src="/file/extensions/sd_api_pictures/outputs/{variadic}.png" alt="{description}" style="max-width: unset; max-height: unset;">\n'
|
||||||
|
else:
|
||||||
# lower the resolution of received images for the chat, otherwise the log size gets out of control quickly with all the base64 values in visible history
|
# lower the resolution of received images for the chat, otherwise the log size gets out of control quickly with all the base64 values in visible history
|
||||||
image.thumbnail((300, 300))
|
image.thumbnail((300, 300))
|
||||||
buffered = io.BytesIO()
|
buffered = io.BytesIO()
|
||||||
@@ -97,6 +155,9 @@ def get_SD_pictures(description):
|
|||||||
img_str = "data:image/jpeg;base64," + base64.b64encode(image_bytes).decode()
|
img_str = "data:image/jpeg;base64," + base64.b64encode(image_bytes).decode()
|
||||||
visible_result = visible_result + f'<img src="{img_str}" alt="{description}">\n'
|
visible_result = visible_result + f'<img src="{img_str}" alt="{description}">\n'
|
||||||
|
|
||||||
|
if params['manage_VRAM']:
|
||||||
|
give_VRAM_priority('LLM')
|
||||||
|
|
||||||
return visible_result
|
return visible_result
|
||||||
|
|
||||||
# TODO: how do I make the UI history ignore the resulting pictures (I don't want HTML to appear in history)
|
# TODO: how do I make the UI history ignore the resulting pictures (I don't want HTML to appear in history)
|
||||||
@@ -105,7 +166,8 @@ def output_modifier(string):
|
|||||||
"""
|
"""
|
||||||
This function is applied to the model outputs.
|
This function is applied to the model outputs.
|
||||||
"""
|
"""
|
||||||
global pic_id, picture_response, streaming_state
|
|
||||||
|
global picture_response, params
|
||||||
|
|
||||||
if not picture_response:
|
if not picture_response:
|
||||||
return string
|
return string
|
||||||
@@ -118,17 +180,19 @@ def output_modifier(string):
|
|||||||
|
|
||||||
if string == '':
|
if string == '':
|
||||||
string = 'no viable description in reply, try regenerating'
|
string = 'no viable description in reply, try regenerating'
|
||||||
|
return string
|
||||||
|
|
||||||
# I can't for the love of all that's holy get the name from shared.gradio['name1'], so for now it will be like this
|
text = ""
|
||||||
text = f'*Description: "{string}"*'
|
if (params['mode'] < 2):
|
||||||
|
toggle_generation(False)
|
||||||
|
text = f'*Sends a picture which portrays: “{string}”*'
|
||||||
|
else:
|
||||||
|
text = string
|
||||||
|
|
||||||
image = get_SD_pictures(string)
|
string = get_SD_pictures(string) + "\n" + text
|
||||||
|
|
||||||
picture_response = False
|
return string
|
||||||
|
|
||||||
shared.processing_message = "*Is typing...*"
|
|
||||||
shared.args.no_stream = streaming_state
|
|
||||||
return image + "\n" + text
|
|
||||||
|
|
||||||
def bot_prefix_modifier(string):
|
def bot_prefix_modifier(string):
|
||||||
"""
|
"""
|
||||||
@@ -139,41 +203,92 @@ def bot_prefix_modifier(string):
|
|||||||
|
|
||||||
return string
|
return string
|
||||||
|
|
||||||
def force_pic():
|
|
||||||
global picture_response
|
def toggle_generation(*args):
|
||||||
picture_response = True
|
global picture_response, shared, streaming_state
|
||||||
|
|
||||||
|
if not args:
|
||||||
|
picture_response = not picture_response
|
||||||
|
else:
|
||||||
|
picture_response = args[0]
|
||||||
|
|
||||||
|
shared.args.no_stream = True if picture_response else streaming_state # Disable streaming cause otherwise the SD-generated picture would return as a dud
|
||||||
|
shared.processing_message = "*Is sending a picture...*" if picture_response else "*Is typing...*"
|
||||||
|
|
||||||
|
|
||||||
|
def filter_address(address):
|
||||||
|
address = address.strip()
|
||||||
|
# address = re.sub('http(s)?:\/\/|\/$','',address) # remove starting http:// OR https:// OR trailing slash
|
||||||
|
address = re.sub('\/$', '', address) # remove trailing /s
|
||||||
|
if not address.startswith('http'):
|
||||||
|
address = 'http://' + address
|
||||||
|
return address
|
||||||
|
|
||||||
|
|
||||||
|
def SD_api_address_update(address):
|
||||||
|
|
||||||
|
global params
|
||||||
|
|
||||||
|
msg = "✔️ SD API is found on:"
|
||||||
|
address = filter_address(address)
|
||||||
|
params.update({"address": address})
|
||||||
|
try:
|
||||||
|
response = requests.get(url=f'{params["address"]}/sdapi/v1/sd-models')
|
||||||
|
response.raise_for_status()
|
||||||
|
# r = response.json()
|
||||||
|
except:
|
||||||
|
msg = "❌ No SD API endpoint on:"
|
||||||
|
|
||||||
|
return gr.Textbox.update(label=msg)
|
||||||
|
|
||||||
|
|
||||||
def ui():
|
def ui():
|
||||||
|
|
||||||
# Gradio elements
|
# Gradio elements
|
||||||
with gr.Accordion("Stable Diffusion api integration", open=True):
|
# gr.Markdown('### Stable Diffusion API Pictures') # Currently the name of extension is shown as the title
|
||||||
|
with gr.Accordion("Parameters", open=True):
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
with gr.Column():
|
address = gr.Textbox(placeholder=params['address'], value=params['address'], label='Auto1111\'s WebUI address')
|
||||||
enable = gr.Checkbox(value=params['enable_SD_api'], label='Activate SD Api integration')
|
mode = gr.Dropdown(["Manual", "Immersive/Interactive", "Picturebook/Adventure"], value="Manual", label="Mode of operation", type="index")
|
||||||
save_img = gr.Checkbox(value=params['save_img'], label='Keep original received images in the outputs subdir')
|
with gr.Column(scale=1, min_width=300):
|
||||||
with gr.Column():
|
manage_VRAM = gr.Checkbox(value=params['manage_VRAM'], label='Manage VRAM')
|
||||||
address = gr.Textbox(placeholder=params['address'], value=params['address'], label='Stable Diffusion host address')
|
save_img = gr.Checkbox(value=params['save_img'], label='Keep original images and use them in chat')
|
||||||
|
|
||||||
with gr.Row():
|
force_pic = gr.Button("Force the picture response")
|
||||||
force_btn = gr.Button("Force the next response to be a picture")
|
suppr_pic = gr.Button("Suppress the picture response")
|
||||||
generate_now_btn = gr.Button("Generate an image response to the input")
|
|
||||||
|
|
||||||
with gr.Accordion("Generation parameters", open=False):
|
with gr.Accordion("Generation parameters", open=False):
|
||||||
prompt_prefix = gr.Textbox(placeholder=params['prompt_prefix'], value=params['prompt_prefix'], label='Prompt Prefix (best used to describe the look of the character)')
|
prompt_prefix = gr.Textbox(placeholder=params['prompt_prefix'], value=params['prompt_prefix'], label='Prompt Prefix (best used to describe the look of the character)')
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
|
with gr.Column():
|
||||||
negative_prompt = gr.Textbox(placeholder=params['negative_prompt'], value=params['negative_prompt'], label='Negative Prompt')
|
negative_prompt = gr.Textbox(placeholder=params['negative_prompt'], value=params['negative_prompt'], label='Negative Prompt')
|
||||||
dimensions = gr.Slider(256,702,value=params['side_length'],step=64,label='Image dimensions')
|
sampler_name = gr.Textbox(placeholder=params['sampler_name'], value=params['sampler_name'], label='Sampler')
|
||||||
# model = gr.Dropdown(value=SD_models[0], choices=SD_models, label='Model')
|
with gr.Column():
|
||||||
|
width = gr.Slider(256, 768, value=params['width'], step=64, label='Width')
|
||||||
|
height = gr.Slider(256, 768, value=params['height'], step=64, label='Height')
|
||||||
|
with gr.Row():
|
||||||
|
steps = gr.Number(label="Steps:", value=params['steps'])
|
||||||
|
seed = gr.Number(label="Seed:", value=params['seed'])
|
||||||
|
cfg_scale = gr.Number(label="CFG Scale:", value=params['cfg_scale'])
|
||||||
|
|
||||||
# Event functions to update the parameters in the backend
|
# Event functions to update the parameters in the backend
|
||||||
enable.change(lambda x: params.update({"enable_SD_api": x}), enable, None)
|
address.change(lambda x: params.update({"address": filter_address(x)}), address, None)
|
||||||
|
mode.select(lambda x: params.update({"mode": x}), mode, None)
|
||||||
|
mode.select(lambda x: toggle_generation(x > 1), inputs=mode, outputs=None)
|
||||||
|
manage_VRAM.change(lambda x: params.update({"manage_VRAM": x}), manage_VRAM, None)
|
||||||
|
manage_VRAM.change(lambda x: give_VRAM_priority('set' if x else 'reset'), inputs=manage_VRAM, outputs=None)
|
||||||
save_img.change(lambda x: params.update({"save_img": x}), save_img, None)
|
save_img.change(lambda x: params.update({"save_img": x}), save_img, None)
|
||||||
address.change(lambda x: params.update({"address": x}), address, None)
|
|
||||||
|
address.submit(fn=SD_api_address_update, inputs=address, outputs=address)
|
||||||
prompt_prefix.change(lambda x: params.update({"prompt_prefix": x}), prompt_prefix, None)
|
prompt_prefix.change(lambda x: params.update({"prompt_prefix": x}), prompt_prefix, None)
|
||||||
negative_prompt.change(lambda x: params.update({"negative_prompt": x}), negative_prompt, None)
|
negative_prompt.change(lambda x: params.update({"negative_prompt": x}), negative_prompt, None)
|
||||||
dimensions.change(lambda x: params.update({"side_length": x}), dimensions, None)
|
width.change(lambda x: params.update({"width": x}), width, None)
|
||||||
# model.change(lambda x: params.update({"SD_model": x}), model, None)
|
height.change(lambda x: params.update({"height": x}), height, None)
|
||||||
|
|
||||||
force_btn.click(force_pic)
|
sampler_name.change(lambda x: params.update({"sampler_name": x}), sampler_name, None)
|
||||||
generate_now_btn.click(force_pic)
|
steps.change(lambda x: params.update({"steps": x}), steps, None)
|
||||||
generate_now_btn.click(chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream)
|
seed.change(lambda x: params.update({"seed": x}), seed, None)
|
||||||
|
cfg_scale.change(lambda x: params.update({"cfg_scale": x}), cfg_scale, None)
|
||||||
|
|
||||||
|
force_pic.click(lambda x: toggle_generation(True), inputs=force_pic, outputs=None)
|
||||||
|
suppr_pic.click(lambda x: toggle_generation(False), inputs=suppr_pic, outputs=None)
|
||||||
|
|||||||
@@ -17,13 +17,15 @@ input_hijack = {
|
|||||||
processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
|
processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
|
||||||
model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base", torch_dtype=torch.float32).to("cpu")
|
model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base", torch_dtype=torch.float32).to("cpu")
|
||||||
|
|
||||||
|
|
||||||
def caption_image(raw_image):
|
def caption_image(raw_image):
|
||||||
inputs = processor(raw_image.convert('RGB'), return_tensors="pt").to("cpu", torch.float32)
|
inputs = processor(raw_image.convert('RGB'), return_tensors="pt").to("cpu", torch.float32)
|
||||||
out = model.generate(**inputs, max_new_tokens=100)
|
out = model.generate(**inputs, max_new_tokens=100)
|
||||||
return processor.decode(out[0], skip_special_tokens=True)
|
return processor.decode(out[0], skip_special_tokens=True)
|
||||||
|
|
||||||
|
|
||||||
def generate_chat_picture(picture, name1, name2):
|
def generate_chat_picture(picture, name1, name2):
|
||||||
text = f'*{name1} sends {name2} a picture that contains the following: "{caption_image(picture)}"*'
|
text = f'*{name1} sends {name2} a picture that contains the following: “{caption_image(picture)}”*'
|
||||||
# lower the resolution of sent images for the chat, otherwise the log size gets out of control quickly with all the base64 values in visible history
|
# lower the resolution of sent images for the chat, otherwise the log size gets out of control quickly with all the base64 values in visible history
|
||||||
picture.thumbnail((300, 300))
|
picture.thumbnail((300, 300))
|
||||||
buffer = BytesIO()
|
buffer = BytesIO()
|
||||||
@@ -32,6 +34,7 @@ def generate_chat_picture(picture, name1, name2):
|
|||||||
visible_text = f'<img src="data:image/jpeg;base64,{img_str}" alt="{text}">'
|
visible_text = f'<img src="data:image/jpeg;base64,{img_str}" alt="{text}">'
|
||||||
return text, visible_text
|
return text, visible_text
|
||||||
|
|
||||||
|
|
||||||
def ui():
|
def ui():
|
||||||
picture_select = gr.Image(label='Send a picture', type='pil')
|
picture_select = gr.Image(label='Send a picture', type='pil')
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,5 @@
|
|||||||
ipython
|
ipython
|
||||||
|
num2words
|
||||||
omegaconf
|
omegaconf
|
||||||
pydub
|
pydub
|
||||||
PyYAML
|
PyYAML
|
||||||
torch
|
|
||||||
torchaudio
|
|
||||||
|
|||||||
@@ -1,14 +1,16 @@
|
|||||||
import re
|
|
||||||
import time
|
import time
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
import modules.chat as chat
|
|
||||||
import modules.shared as shared
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from extensions.silero_tts import tts_preprocessor
|
||||||
|
from modules import chat, shared
|
||||||
|
from modules.html_generator import chat_html_wrapper
|
||||||
|
|
||||||
torch._C._jit_set_profiling_mode(False)
|
torch._C._jit_set_profiling_mode(False)
|
||||||
|
|
||||||
|
|
||||||
params = {
|
params = {
|
||||||
'activate': True,
|
'activate': True,
|
||||||
'speaker': 'en_56',
|
'speaker': 'en_56',
|
||||||
@@ -20,6 +22,7 @@ params = {
|
|||||||
'autoplay': True,
|
'autoplay': True,
|
||||||
'voice_pitch': 'medium',
|
'voice_pitch': 'medium',
|
||||||
'voice_speed': 'medium',
|
'voice_speed': 'medium',
|
||||||
|
'local_cache_path': '' # User can override the default cache path to something other via settings.json
|
||||||
}
|
}
|
||||||
|
|
||||||
current_params = params.copy()
|
current_params = params.copy()
|
||||||
@@ -37,26 +40,31 @@ table = str.maketrans({
|
|||||||
'"': """,
|
'"': """,
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
||||||
def xmlesc(txt):
|
def xmlesc(txt):
|
||||||
return txt.translate(table)
|
return txt.translate(table)
|
||||||
|
|
||||||
|
|
||||||
def load_model():
|
def load_model():
|
||||||
|
torch_cache_path = torch.hub.get_dir() if params['local_cache_path'] == '' else params['local_cache_path']
|
||||||
|
model_path = torch_cache_path + "/snakers4_silero-models_master/src/silero/model/" + params['model_id'] + ".pt"
|
||||||
|
if Path(model_path).is_file():
|
||||||
|
print(f'\nUsing Silero TTS cached checkpoint found at {torch_cache_path}')
|
||||||
|
model, example_text = torch.hub.load(repo_or_dir=torch_cache_path + '/snakers4_silero-models_master/', model='silero_tts', language=params['language'], speaker=params['model_id'], source='local', path=model_path, force_reload=True)
|
||||||
|
else:
|
||||||
|
print(f'\nSilero TTS cache not found at {torch_cache_path}. Attempting to download...')
|
||||||
model, example_text = torch.hub.load(repo_or_dir='snakers4/silero-models', model='silero_tts', language=params['language'], speaker=params['model_id'])
|
model, example_text = torch.hub.load(repo_or_dir='snakers4/silero-models', model='silero_tts', language=params['language'], speaker=params['model_id'])
|
||||||
model.to(params['device'])
|
model.to(params['device'])
|
||||||
return model
|
return model
|
||||||
model = load_model()
|
|
||||||
|
|
||||||
def remove_surrounded_chars(string):
|
|
||||||
# this expression matches to 'as few symbols as possible (0 upwards) between any asterisks' OR
|
|
||||||
# 'as few symbols as possible (0 upwards) between an asterisk and the end of the string'
|
|
||||||
return re.sub('\*[^\*]*?(\*|$)','',string)
|
|
||||||
|
|
||||||
def remove_tts_from_history(name1, name2):
|
def remove_tts_from_history(name1, name2, mode):
|
||||||
for i, entry in enumerate(shared.history['internal']):
|
for i, entry in enumerate(shared.history['internal']):
|
||||||
shared.history['visible'][i] = [shared.history['visible'][i][0], entry[1]]
|
shared.history['visible'][i] = [shared.history['visible'][i][0], entry[1]]
|
||||||
return chat.generate_chat_output(shared.history['visible'], name1, name2, shared.character)
|
return chat_html_wrapper(shared.history['visible'], name1, name2, mode)
|
||||||
|
|
||||||
def toggle_text_in_history(name1, name2):
|
|
||||||
|
def toggle_text_in_history(name1, name2, mode):
|
||||||
for i, entry in enumerate(shared.history['visible']):
|
for i, entry in enumerate(shared.history['visible']):
|
||||||
visible_reply = entry[1]
|
visible_reply = entry[1]
|
||||||
if visible_reply.startswith('<audio'):
|
if visible_reply.startswith('<audio'):
|
||||||
@@ -65,7 +73,8 @@ def toggle_text_in_history(name1, name2):
|
|||||||
shared.history['visible'][i] = [shared.history['visible'][i][0], f"{visible_reply.split('</audio>')[0]}</audio>\n\n{reply}"]
|
shared.history['visible'][i] = [shared.history['visible'][i][0], f"{visible_reply.split('</audio>')[0]}</audio>\n\n{reply}"]
|
||||||
else:
|
else:
|
||||||
shared.history['visible'][i] = [shared.history['visible'][i][0], f"{visible_reply.split('</audio>')[0]}</audio>"]
|
shared.history['visible'][i] = [shared.history['visible'][i][0], f"{visible_reply.split('</audio>')[0]}</audio>"]
|
||||||
return chat.generate_chat_output(shared.history['visible'], name1, name2, shared.character)
|
return chat_html_wrapper(shared.history['visible'], name1, name2, mode)
|
||||||
|
|
||||||
|
|
||||||
def input_modifier(string):
|
def input_modifier(string):
|
||||||
"""
|
"""
|
||||||
@@ -81,6 +90,7 @@ def input_modifier(string):
|
|||||||
shared.args.no_stream = True # Disable streaming cause otherwise the audio output will stutter and begin anew every time the message is being updated
|
shared.args.no_stream = True # Disable streaming cause otherwise the audio output will stutter and begin anew every time the message is being updated
|
||||||
return string
|
return string
|
||||||
|
|
||||||
|
|
||||||
def output_modifier(string):
|
def output_modifier(string):
|
||||||
"""
|
"""
|
||||||
This function is applied to the model outputs.
|
This function is applied to the model outputs.
|
||||||
@@ -94,15 +104,11 @@ def output_modifier(string):
|
|||||||
current_params = params.copy()
|
current_params = params.copy()
|
||||||
break
|
break
|
||||||
|
|
||||||
if params['activate'] == False:
|
if not params['activate']:
|
||||||
return string
|
return string
|
||||||
|
|
||||||
original_string = string
|
original_string = string
|
||||||
string = remove_surrounded_chars(string)
|
string = tts_preprocessor.preprocess(string)
|
||||||
string = string.replace('"', '')
|
|
||||||
string = string.replace('“', '')
|
|
||||||
string = string.replace('\n', ' ')
|
|
||||||
string = string.strip()
|
|
||||||
|
|
||||||
if string == '':
|
if string == '':
|
||||||
string = '*Empty reply, try regenerating*'
|
string = '*Empty reply, try regenerating*'
|
||||||
@@ -121,6 +127,7 @@ def output_modifier(string):
|
|||||||
shared.args.no_stream = streaming_state # restore the streaming option to the previous value
|
shared.args.no_stream = streaming_state # restore the streaming option to the previous value
|
||||||
return string
|
return string
|
||||||
|
|
||||||
|
|
||||||
def bot_prefix_modifier(string):
|
def bot_prefix_modifier(string):
|
||||||
"""
|
"""
|
||||||
This function is only applied in chat mode. It modifies
|
This function is only applied in chat mode. It modifies
|
||||||
@@ -130,17 +137,25 @@ def bot_prefix_modifier(string):
|
|||||||
|
|
||||||
return string
|
return string
|
||||||
|
|
||||||
|
|
||||||
|
def setup():
|
||||||
|
global model
|
||||||
|
model = load_model()
|
||||||
|
|
||||||
|
|
||||||
def ui():
|
def ui():
|
||||||
# Gradio elements
|
# Gradio elements
|
||||||
with gr.Accordion("Silero TTS"):
|
with gr.Accordion("Silero TTS"):
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
activate = gr.Checkbox(value=params['activate'], label='Activate TTS')
|
activate = gr.Checkbox(value=params['activate'], label='Activate TTS')
|
||||||
autoplay = gr.Checkbox(value=params['autoplay'], label='Play TTS automatically')
|
autoplay = gr.Checkbox(value=params['autoplay'], label='Play TTS automatically')
|
||||||
|
|
||||||
show_text = gr.Checkbox(value=params['show_text'], label='Show message text under audio player')
|
show_text = gr.Checkbox(value=params['show_text'], label='Show message text under audio player')
|
||||||
voice = gr.Dropdown(value=params['speaker'], choices=voices_by_gender, label='TTS voice')
|
voice = gr.Dropdown(value=params['speaker'], choices=voices_by_gender, label='TTS voice')
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
v_pitch = gr.Dropdown(value=params['voice_pitch'], choices=voice_pitches, label='Voice pitch')
|
v_pitch = gr.Dropdown(value=params['voice_pitch'], choices=voice_pitches, label='Voice pitch')
|
||||||
v_speed = gr.Dropdown(value=params['voice_speed'], choices=voice_speeds, label='Voice speed')
|
v_speed = gr.Dropdown(value=params['voice_speed'], choices=voice_speeds, label='Voice speed')
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
convert = gr.Button('Permanently replace audios with the message texts')
|
convert = gr.Button('Permanently replace audios with the message texts')
|
||||||
convert_cancel = gr.Button('Cancel', visible=False)
|
convert_cancel = gr.Button('Cancel', visible=False)
|
||||||
@@ -150,13 +165,13 @@ def ui():
|
|||||||
convert_arr = [convert_confirm, convert, convert_cancel]
|
convert_arr = [convert_confirm, convert, convert_cancel]
|
||||||
convert.click(lambda: [gr.update(visible=True), gr.update(visible=False), gr.update(visible=True)], None, convert_arr)
|
convert.click(lambda: [gr.update(visible=True), gr.update(visible=False), gr.update(visible=True)], None, convert_arr)
|
||||||
convert_confirm.click(lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, convert_arr)
|
convert_confirm.click(lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, convert_arr)
|
||||||
convert_confirm.click(remove_tts_from_history, [shared.gradio['name1'], shared.gradio['name2']], shared.gradio['display'])
|
convert_confirm.click(remove_tts_from_history, [shared.gradio[k] for k in ['name1', 'name2', 'mode']], shared.gradio['display'])
|
||||||
convert_confirm.click(lambda: chat.save_history(timestamp=False), [], [], show_progress=False)
|
convert_confirm.click(lambda: chat.save_history(timestamp=False), [], [], show_progress=False)
|
||||||
convert_cancel.click(lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, convert_arr)
|
convert_cancel.click(lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, convert_arr)
|
||||||
|
|
||||||
# Toggle message text in history
|
# Toggle message text in history
|
||||||
show_text.change(lambda x: params.update({"show_text": x}), show_text, None)
|
show_text.change(lambda x: params.update({"show_text": x}), show_text, None)
|
||||||
show_text.change(toggle_text_in_history, [shared.gradio['name1'], shared.gradio['name2']], shared.gradio['display'])
|
show_text.change(toggle_text_in_history, [shared.gradio[k] for k in ['name1', 'name2', 'mode']], shared.gradio['display'])
|
||||||
show_text.change(lambda: chat.save_history(timestamp=False), [], [], show_progress=False)
|
show_text.change(lambda: chat.save_history(timestamp=False), [], [], show_progress=False)
|
||||||
|
|
||||||
# Event functions to update the parameters in the backend
|
# Event functions to update the parameters in the backend
|
||||||
|
|||||||
81
extensions/silero_tts/test_tts.py
Normal file
81
extensions/silero_tts/test_tts.py
Normal file
@@ -0,0 +1,81 @@
|
|||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import tts_preprocessor
|
||||||
|
|
||||||
|
torch._C._jit_set_profiling_mode(False)
|
||||||
|
|
||||||
|
|
||||||
|
params = {
|
||||||
|
'activate': True,
|
||||||
|
'speaker': 'en_49',
|
||||||
|
'language': 'en',
|
||||||
|
'model_id': 'v3_en',
|
||||||
|
'sample_rate': 48000,
|
||||||
|
'device': 'cpu',
|
||||||
|
'show_text': True,
|
||||||
|
'autoplay': True,
|
||||||
|
'voice_pitch': 'medium',
|
||||||
|
'voice_speed': 'medium',
|
||||||
|
}
|
||||||
|
|
||||||
|
current_params = params.copy()
|
||||||
|
voices_by_gender = ['en_99', 'en_45', 'en_18', 'en_117', 'en_49', 'en_51', 'en_68', 'en_0', 'en_26', 'en_56', 'en_74', 'en_5', 'en_38', 'en_53', 'en_21', 'en_37', 'en_107', 'en_10', 'en_82', 'en_16', 'en_41', 'en_12', 'en_67', 'en_61', 'en_14', 'en_11', 'en_39', 'en_52', 'en_24', 'en_97', 'en_28', 'en_72', 'en_94', 'en_36', 'en_4', 'en_43', 'en_88', 'en_25', 'en_65', 'en_6', 'en_44', 'en_75', 'en_91', 'en_60', 'en_109', 'en_85', 'en_101', 'en_108', 'en_50', 'en_96', 'en_64', 'en_92', 'en_76', 'en_33', 'en_116', 'en_48', 'en_98', 'en_86', 'en_62', 'en_54', 'en_95', 'en_55', 'en_111', 'en_3', 'en_83', 'en_8', 'en_47', 'en_59', 'en_1', 'en_2', 'en_7', 'en_9', 'en_13', 'en_15', 'en_17', 'en_19', 'en_20', 'en_22', 'en_23', 'en_27', 'en_29', 'en_30', 'en_31', 'en_32', 'en_34', 'en_35', 'en_40', 'en_42', 'en_46', 'en_57', 'en_58', 'en_63', 'en_66', 'en_69', 'en_70', 'en_71', 'en_73', 'en_77', 'en_78', 'en_79', 'en_80', 'en_81', 'en_84', 'en_87', 'en_89', 'en_90', 'en_93', 'en_100', 'en_102', 'en_103', 'en_104', 'en_105', 'en_106', 'en_110', 'en_112', 'en_113', 'en_114', 'en_115']
|
||||||
|
voice_pitches = ['x-low', 'low', 'medium', 'high', 'x-high']
|
||||||
|
voice_speeds = ['x-slow', 'slow', 'medium', 'fast', 'x-fast']
|
||||||
|
|
||||||
|
# Used for making text xml compatible, needed for voice pitch and speed control
|
||||||
|
table = str.maketrans({
|
||||||
|
"<": "<",
|
||||||
|
">": ">",
|
||||||
|
"&": "&",
|
||||||
|
"'": "'",
|
||||||
|
'"': """,
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
|
def xmlesc(txt):
|
||||||
|
return txt.translate(table)
|
||||||
|
|
||||||
|
|
||||||
|
def load_model():
|
||||||
|
model, example_text = torch.hub.load(repo_or_dir='snakers4/silero-models', model='silero_tts', language=params['language'], speaker=params['model_id'])
|
||||||
|
model.to(params['device'])
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
model = load_model()
|
||||||
|
|
||||||
|
|
||||||
|
def output_modifier(string):
|
||||||
|
"""
|
||||||
|
This function is applied to the model outputs.
|
||||||
|
"""
|
||||||
|
|
||||||
|
global model, current_params
|
||||||
|
|
||||||
|
original_string = string
|
||||||
|
string = tts_preprocessor.preprocess(string)
|
||||||
|
processed_string = string
|
||||||
|
|
||||||
|
if string == '':
|
||||||
|
string = '*Empty reply, try regenerating*'
|
||||||
|
else:
|
||||||
|
output_file = Path(f'extensions/silero_tts/outputs/test_{int(time.time())}.wav')
|
||||||
|
prosody = '<prosody rate="{}" pitch="{}">'.format(params['voice_speed'], params['voice_pitch'])
|
||||||
|
silero_input = f'<speak>{prosody}{xmlesc(string)}</prosody></speak>'
|
||||||
|
model.save_wav(ssml_text=silero_input, speaker=params['speaker'], sample_rate=int(params['sample_rate']), audio_path=str(output_file))
|
||||||
|
|
||||||
|
autoplay = 'autoplay' if params['autoplay'] else ''
|
||||||
|
string = f'<audio src="file/{output_file.as_posix()}" controls {autoplay}></audio>'
|
||||||
|
|
||||||
|
if params['show_text']:
|
||||||
|
string += f'\n\n{original_string}\n\nProcessed:\n{processed_string}'
|
||||||
|
|
||||||
|
print(string)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
import sys
|
||||||
|
output_modifier(sys.argv[1])
|
||||||
194
extensions/silero_tts/tts_preprocessor.py
Normal file
194
extensions/silero_tts/tts_preprocessor.py
Normal file
@@ -0,0 +1,194 @@
|
|||||||
|
import re
|
||||||
|
|
||||||
|
from num2words import num2words
|
||||||
|
|
||||||
|
punctuation = r'[\s,.?!/)\'\]>]'
|
||||||
|
alphabet_map = {
|
||||||
|
"A": " Ei ",
|
||||||
|
"B": " Bee ",
|
||||||
|
"C": " See ",
|
||||||
|
"D": " Dee ",
|
||||||
|
"E": " Eee ",
|
||||||
|
"F": " Eff ",
|
||||||
|
"G": " Jee ",
|
||||||
|
"H": " Eich ",
|
||||||
|
"I": " Eye ",
|
||||||
|
"J": " Jay ",
|
||||||
|
"K": " Kay ",
|
||||||
|
"L": " El ",
|
||||||
|
"M": " Emm ",
|
||||||
|
"N": " Enn ",
|
||||||
|
"O": " Ohh ",
|
||||||
|
"P": " Pee ",
|
||||||
|
"Q": " Queue ",
|
||||||
|
"R": " Are ",
|
||||||
|
"S": " Ess ",
|
||||||
|
"T": " Tee ",
|
||||||
|
"U": " You ",
|
||||||
|
"V": " Vee ",
|
||||||
|
"W": " Double You ",
|
||||||
|
"X": " Ex ",
|
||||||
|
"Y": " Why ",
|
||||||
|
"Z": " Zed " # Zed is weird, as I (da3dsoul) am American, but most of the voice models sound British, so it matches
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def preprocess(string):
|
||||||
|
# the order for some of these matter
|
||||||
|
# For example, you need to remove the commas in numbers before expanding them
|
||||||
|
string = remove_surrounded_chars(string)
|
||||||
|
string = string.replace('"', '')
|
||||||
|
string = string.replace('\u201D', '').replace('\u201C', '') # right and left quote
|
||||||
|
string = string.replace('\u201F', '') # italic looking quote
|
||||||
|
string = string.replace('\n', ' ')
|
||||||
|
string = convert_num_locale(string)
|
||||||
|
string = replace_negative(string)
|
||||||
|
string = replace_roman(string)
|
||||||
|
string = hyphen_range_to(string)
|
||||||
|
string = num_to_words(string)
|
||||||
|
|
||||||
|
# TODO Try to use a ML predictor to expand abbreviations. It's hard, dependent on context, and whether to actually
|
||||||
|
# try to say the abbreviation or spell it out as I've done below is not agreed upon
|
||||||
|
|
||||||
|
# For now, expand abbreviations to pronunciations
|
||||||
|
# replace_abbreviations adds a lot of unnecessary whitespace to ensure separation
|
||||||
|
string = replace_abbreviations(string)
|
||||||
|
string = replace_lowercase_abbreviations(string)
|
||||||
|
|
||||||
|
# cleanup whitespaces
|
||||||
|
# remove whitespace before punctuation
|
||||||
|
string = re.sub(rf'\s+({punctuation})', r'\1', string)
|
||||||
|
string = string.strip()
|
||||||
|
# compact whitespace
|
||||||
|
string = ' '.join(string.split())
|
||||||
|
|
||||||
|
return string
|
||||||
|
|
||||||
|
|
||||||
|
def remove_surrounded_chars(string):
|
||||||
|
# this expression matches to 'as few symbols as possible (0 upwards) between any asterisks' OR
|
||||||
|
# 'as few symbols as possible (0 upwards) between an asterisk and the end of the string'
|
||||||
|
return re.sub(r'\*[^*]*?(\*|$)', '', string)
|
||||||
|
|
||||||
|
|
||||||
|
def convert_num_locale(text):
|
||||||
|
# This detects locale and converts it to American without comma separators
|
||||||
|
pattern = re.compile(r'(?:\s|^)\d{1,3}(?:\.\d{3})+(,\d+)(?:\s|$)')
|
||||||
|
result = text
|
||||||
|
while True:
|
||||||
|
match = pattern.search(result)
|
||||||
|
if match is None:
|
||||||
|
break
|
||||||
|
|
||||||
|
start = match.start()
|
||||||
|
end = match.end()
|
||||||
|
result = result[0:start] + result[start:end].replace('.', '').replace(',', '.') + result[end:len(result)]
|
||||||
|
|
||||||
|
# removes comma separators from existing American numbers
|
||||||
|
pattern = re.compile(r'(\d),(\d)')
|
||||||
|
result = pattern.sub(r'\1\2', result)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def replace_negative(string):
|
||||||
|
# handles situations like -5. -5 would become negative 5, which would then be expanded to negative five
|
||||||
|
return re.sub(rf'(\s)(-)(\d+)({punctuation})', r'\1negative \3\4', string)
|
||||||
|
|
||||||
|
|
||||||
|
def replace_roman(string):
|
||||||
|
# find a string of roman numerals.
|
||||||
|
# Only 2 or more, to avoid capturing I and single character abbreviations, like names
|
||||||
|
pattern = re.compile(rf'\s[IVXLCDM]{{2,}}{punctuation}')
|
||||||
|
result = string
|
||||||
|
while True:
|
||||||
|
match = pattern.search(result)
|
||||||
|
if match is None:
|
||||||
|
break
|
||||||
|
|
||||||
|
start = match.start()
|
||||||
|
end = match.end()
|
||||||
|
result = result[0:start + 1] + str(roman_to_int(result[start + 1:end - 1])) + result[end - 1:len(result)]
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def roman_to_int(s):
|
||||||
|
rom_val = {'I': 1, 'V': 5, 'X': 10, 'L': 50, 'C': 100, 'D': 500, 'M': 1000}
|
||||||
|
int_val = 0
|
||||||
|
for i in range(len(s)):
|
||||||
|
if i > 0 and rom_val[s[i]] > rom_val[s[i - 1]]:
|
||||||
|
int_val += rom_val[s[i]] - 2 * rom_val[s[i - 1]]
|
||||||
|
else:
|
||||||
|
int_val += rom_val[s[i]]
|
||||||
|
return int_val
|
||||||
|
|
||||||
|
|
||||||
|
def hyphen_range_to(text):
|
||||||
|
pattern = re.compile(r'(\d+)[-–](\d+)')
|
||||||
|
result = pattern.sub(lambda x: x.group(1) + ' to ' + x.group(2), text)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def num_to_words(text):
|
||||||
|
# 1000 or 10.23
|
||||||
|
pattern = re.compile(r'\d+\.\d+|\d+')
|
||||||
|
result = pattern.sub(lambda x: num2words(float(x.group())), text)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def replace_abbreviations(string):
|
||||||
|
# abbreviations 1 to 4 characters long. It will get things like A and I, but those are pronounced with their letter
|
||||||
|
pattern = re.compile(rf'(^|[\s(.\'\[<])([A-Z]{{1,4}})({punctuation}|$)')
|
||||||
|
result = string
|
||||||
|
while True:
|
||||||
|
match = pattern.search(result)
|
||||||
|
if match is None:
|
||||||
|
break
|
||||||
|
|
||||||
|
start = match.start()
|
||||||
|
end = match.end()
|
||||||
|
result = result[0:start] + replace_abbreviation(result[start:end]) + result[end:len(result)]
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def replace_lowercase_abbreviations(string):
|
||||||
|
# abbreviations 1 to 4 characters long, separated by dots i.e. e.g.
|
||||||
|
pattern = re.compile(rf'(^|[\s(.\'\[<])(([a-z]\.){{1,4}})({punctuation}|$)')
|
||||||
|
result = string
|
||||||
|
while True:
|
||||||
|
match = pattern.search(result)
|
||||||
|
if match is None:
|
||||||
|
break
|
||||||
|
|
||||||
|
start = match.start()
|
||||||
|
end = match.end()
|
||||||
|
result = result[0:start] + replace_abbreviation(result[start:end].upper()) + result[end:len(result)]
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def replace_abbreviation(string):
|
||||||
|
result = ""
|
||||||
|
for char in string:
|
||||||
|
result += match_mapping(char)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def match_mapping(char):
|
||||||
|
for mapping in alphabet_map.keys():
|
||||||
|
if char == mapping:
|
||||||
|
return alphabet_map[char]
|
||||||
|
|
||||||
|
return char
|
||||||
|
|
||||||
|
|
||||||
|
def __main__(args):
|
||||||
|
print(preprocess(args[1]))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import sys
|
||||||
|
__main__(sys.argv)
|
||||||
@@ -1,5 +1,6 @@
|
|||||||
import gradio as gr
|
import gradio as gr
|
||||||
import speech_recognition as sr
|
import speech_recognition as sr
|
||||||
|
from modules import shared
|
||||||
|
|
||||||
input_hijack = {
|
input_hijack = {
|
||||||
'state': False,
|
'state': False,
|
||||||
@@ -7,7 +8,7 @@ input_hijack = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def do_stt(audio, text_state=""):
|
def do_stt(audio):
|
||||||
transcription = ""
|
transcription = ""
|
||||||
r = sr.Recognizer()
|
r = sr.Recognizer()
|
||||||
|
|
||||||
@@ -21,34 +22,23 @@ def do_stt(audio, text_state=""):
|
|||||||
except sr.RequestError as e:
|
except sr.RequestError as e:
|
||||||
print("Could not request results from Whisper", e)
|
print("Could not request results from Whisper", e)
|
||||||
|
|
||||||
input_hijack.update({"state": True, "value": [transcription, transcription]})
|
return transcription
|
||||||
|
|
||||||
text_state += transcription + " "
|
|
||||||
return text_state, text_state
|
|
||||||
|
|
||||||
|
|
||||||
def update_hijack(val):
|
def auto_transcribe(audio, auto_submit):
|
||||||
input_hijack.update({"state": True, "value": [val, val]})
|
|
||||||
return val
|
|
||||||
|
|
||||||
|
|
||||||
def auto_transcribe(audio, audio_auto, text_state=""):
|
|
||||||
if audio is None:
|
if audio is None:
|
||||||
return "", ""
|
return "", ""
|
||||||
if audio_auto:
|
|
||||||
return do_stt(audio, text_state)
|
transcription = do_stt(audio)
|
||||||
return "", ""
|
if auto_submit:
|
||||||
|
input_hijack.update({"state": True, "value": [transcription, transcription]})
|
||||||
|
|
||||||
|
return transcription, None
|
||||||
|
|
||||||
|
|
||||||
def ui():
|
def ui():
|
||||||
tr_state = gr.State(value="")
|
|
||||||
output_transcription = gr.Textbox(label="STT-Input",
|
|
||||||
placeholder="Speech Preview. Click \"Generate\" to send",
|
|
||||||
interactive=True)
|
|
||||||
output_transcription.change(fn=update_hijack, inputs=[output_transcription], outputs=[tr_state])
|
|
||||||
audio_auto = gr.Checkbox(label="Auto-Transcribe", value=True)
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
audio = gr.Audio(source="microphone")
|
audio = gr.Audio(source="microphone")
|
||||||
audio.change(fn=auto_transcribe, inputs=[audio, audio_auto, tr_state], outputs=[output_transcription, tr_state])
|
auto_submit = gr.Checkbox(label='Submit the transcribed audio automatically', value=True)
|
||||||
transcribe_button = gr.Button(value="Transcribe")
|
audio.change(fn=auto_transcribe, inputs=[audio, auto_submit], outputs=[shared.gradio['textbox'], audio])
|
||||||
transcribe_button.click(do_stt, inputs=[audio, tr_state], outputs=[output_transcription, tr_state])
|
audio.change(None, auto_submit, None, _js="(check) => {if (check) { document.getElementById('Generate').click() }}")
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
import inspect
|
||||||
import re
|
import re
|
||||||
import sys
|
import sys
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@@ -16,9 +17,11 @@ from quant import make_quant
|
|||||||
|
|
||||||
|
|
||||||
def _load_quant(model, checkpoint, wbits, groupsize=-1, faster_kernel=False, exclude_layers=['lm_head'], kernel_switch_threshold=128):
|
def _load_quant(model, checkpoint, wbits, groupsize=-1, faster_kernel=False, exclude_layers=['lm_head'], kernel_switch_threshold=128):
|
||||||
config = AutoConfig.from_pretrained(model)
|
|
||||||
def noop(*args, **kwargs):
|
def noop(*args, **kwargs):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
config = AutoConfig.from_pretrained(model)
|
||||||
torch.nn.init.kaiming_uniform_ = noop
|
torch.nn.init.kaiming_uniform_ = noop
|
||||||
torch.nn.init.uniform_ = noop
|
torch.nn.init.uniform_ = noop
|
||||||
torch.nn.init.normal_ = noop
|
torch.nn.init.normal_ = noop
|
||||||
@@ -33,21 +36,37 @@ def _load_quant(model, checkpoint, wbits, groupsize=-1, faster_kernel=False, exc
|
|||||||
for name in exclude_layers:
|
for name in exclude_layers:
|
||||||
if name in layers:
|
if name in layers:
|
||||||
del layers[name]
|
del layers[name]
|
||||||
make_quant(model, layers, wbits, groupsize, faster=faster_kernel, kernel_switch_threshold=kernel_switch_threshold)
|
|
||||||
|
gptq_args = inspect.getfullargspec(make_quant).args
|
||||||
|
|
||||||
|
make_quant_kwargs = {
|
||||||
|
'module': model,
|
||||||
|
'names': layers,
|
||||||
|
'bits': wbits,
|
||||||
|
}
|
||||||
|
if 'groupsize' in gptq_args:
|
||||||
|
make_quant_kwargs['groupsize'] = groupsize
|
||||||
|
if 'faster' in gptq_args:
|
||||||
|
make_quant_kwargs['faster'] = faster_kernel
|
||||||
|
if 'kernel_switch_threshold' in gptq_args:
|
||||||
|
make_quant_kwargs['kernel_switch_threshold'] = kernel_switch_threshold
|
||||||
|
|
||||||
|
make_quant(**make_quant_kwargs)
|
||||||
|
|
||||||
del layers
|
del layers
|
||||||
|
|
||||||
print('Loading model ...')
|
print('Loading model ...')
|
||||||
if checkpoint.endswith('.safetensors'):
|
if checkpoint.endswith('.safetensors'):
|
||||||
from safetensors.torch import load_file as safe_load
|
from safetensors.torch import load_file as safe_load
|
||||||
model.load_state_dict(safe_load(checkpoint))
|
model.load_state_dict(safe_load(checkpoint), strict=False)
|
||||||
else:
|
else:
|
||||||
model.load_state_dict(torch.load(checkpoint))
|
model.load_state_dict(torch.load(checkpoint), strict=False)
|
||||||
model.seqlen = 2048
|
model.seqlen = 2048
|
||||||
print('Done.')
|
print('Done.')
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
def load_quantized(model_name):
|
def load_quantized(model_name):
|
||||||
if not shared.args.model_type:
|
if not shared.args.model_type:
|
||||||
# Try to determine model type from model name
|
# Try to determine model type from model name
|
||||||
@@ -81,10 +100,10 @@ def load_quantized(model_name):
|
|||||||
found_safetensors = list(path_to_model.glob("*.safetensors"))
|
found_safetensors = list(path_to_model.glob("*.safetensors"))
|
||||||
pt_path = None
|
pt_path = None
|
||||||
|
|
||||||
if len(found_pts) == 1:
|
if len(found_pts) > 0:
|
||||||
pt_path = found_pts[0]
|
pt_path = found_pts[-1]
|
||||||
elif len(found_safetensors) == 1:
|
elif len(found_safetensors) > 0:
|
||||||
pt_path = found_safetensors[0]
|
pt_path = found_safetensors[-1]
|
||||||
else:
|
else:
|
||||||
if path_to_model.name.lower().startswith('llama-7b'):
|
if path_to_model.name.lower().startswith('llama-7b'):
|
||||||
pt_model = f'llama-7b-{shared.args.wbits}bit'
|
pt_model = f'llama-7b-{shared.args.wbits}bit'
|
||||||
@@ -100,13 +119,14 @@ def load_quantized(model_name):
|
|||||||
# Try to find the .safetensors or .pt both in the model dir and in the subfolder
|
# Try to find the .safetensors or .pt both in the model dir and in the subfolder
|
||||||
for path in [Path(p + ext) for ext in ['.safetensors', '.pt'] for p in [f"{shared.args.model_dir}/{pt_model}", f"{path_to_model}/{pt_model}"]]:
|
for path in [Path(p + ext) for ext in ['.safetensors', '.pt'] for p in [f"{shared.args.model_dir}/{pt_model}", f"{path_to_model}/{pt_model}"]]:
|
||||||
if path.exists():
|
if path.exists():
|
||||||
print(f"Found {path}")
|
|
||||||
pt_path = path
|
pt_path = path
|
||||||
break
|
break
|
||||||
|
|
||||||
if not pt_path:
|
if not pt_path:
|
||||||
print("Could not find the quantized model in .pt or .safetensors format, exiting...")
|
print("Could not find the quantized model in .pt or .safetensors format, exiting...")
|
||||||
exit()
|
exit()
|
||||||
|
else:
|
||||||
|
print(f"Found the following quantized model: {pt_path}")
|
||||||
|
|
||||||
# qwopqwop200's offload
|
# qwopqwop200's offload
|
||||||
if model_type == 'llama' and shared.args.pre_layer:
|
if model_type == 'llama' and shared.args.pre_layer:
|
||||||
|
|||||||
@@ -4,15 +4,9 @@ import torch
|
|||||||
from peft import PeftModel
|
from peft import PeftModel
|
||||||
|
|
||||||
import modules.shared as shared
|
import modules.shared as shared
|
||||||
from modules.models import load_model
|
from modules.models import reload_model
|
||||||
from modules.text_generation import clear_torch_cache
|
|
||||||
|
|
||||||
|
|
||||||
def reload_model():
|
|
||||||
shared.model = shared.tokenizer = None
|
|
||||||
clear_torch_cache()
|
|
||||||
shared.model, shared.tokenizer = load_model(shared.model_name)
|
|
||||||
|
|
||||||
def add_lora_to_model(lora_name):
|
def add_lora_to_model(lora_name):
|
||||||
|
|
||||||
# If a LoRA had been previously loaded, or if we want
|
# If a LoRA had been previously loaded, or if we want
|
||||||
|
|||||||
@@ -54,6 +54,7 @@ class RWKVModel:
|
|||||||
reply += token
|
reply += token
|
||||||
yield reply
|
yield reply
|
||||||
|
|
||||||
|
|
||||||
class RWKVTokenizer:
|
class RWKVTokenizer:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
pass
|
pass
|
||||||
|
|||||||
@@ -28,6 +28,7 @@ def generate_reply_wrapper(string):
|
|||||||
for i in generate_reply(params[0], generate_params):
|
for i in generate_reply(params[0], generate_params):
|
||||||
yield i
|
yield i
|
||||||
|
|
||||||
|
|
||||||
def create_apis():
|
def create_apis():
|
||||||
t1 = gr.Textbox(visible=False)
|
t1 = gr.Textbox(visible=False)
|
||||||
t2 = gr.Textbox(visible=False)
|
t2 = gr.Textbox(visible=False)
|
||||||
|
|||||||
@@ -30,6 +30,7 @@ class _SentinelTokenStoppingCriteria(transformers.StoppingCriteria):
|
|||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
class Stream(transformers.StoppingCriteria):
|
class Stream(transformers.StoppingCriteria):
|
||||||
def __init__(self, callback_func=None):
|
def __init__(self, callback_func=None):
|
||||||
self.callback_func = callback_func
|
self.callback_func = callback_func
|
||||||
@@ -39,6 +40,7 @@ class Stream(transformers.StoppingCriteria):
|
|||||||
self.callback_func(input_ids[0])
|
self.callback_func(input_ids[0])
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
class Iteratorize:
|
class Iteratorize:
|
||||||
|
|
||||||
"""
|
"""
|
||||||
@@ -96,6 +98,7 @@ class Iteratorize:
|
|||||||
self.stop_now = True
|
self.stop_now = True
|
||||||
clear_torch_cache()
|
clear_torch_cache()
|
||||||
|
|
||||||
|
|
||||||
def clear_torch_cache():
|
def clear_torch_cache():
|
||||||
gc.collect()
|
gc.collect()
|
||||||
if not shared.args.cpu:
|
if not shared.args.cpu:
|
||||||
|
|||||||
230
modules/chat.py
230
modules/chat.py
@@ -12,55 +12,59 @@ from PIL import Image
|
|||||||
import modules.extensions as extensions_module
|
import modules.extensions as extensions_module
|
||||||
import modules.shared as shared
|
import modules.shared as shared
|
||||||
from modules.extensions import apply_extensions
|
from modules.extensions import apply_extensions
|
||||||
from modules.html_generator import (fix_newlines, chat_html_wrapper,
|
from modules.html_generator import (chat_html_wrapper, fix_newlines,
|
||||||
make_thumbnail)
|
make_thumbnail)
|
||||||
from modules.text_generation import (encode, generate_reply,
|
from modules.text_generation import (encode, generate_reply,
|
||||||
get_max_prompt_length)
|
get_max_prompt_length)
|
||||||
|
|
||||||
|
|
||||||
def generate_chat_prompt(user_input, max_new_tokens, name1, name2, context, chat_prompt_size, **kwargs):
|
def generate_chat_prompt(user_input, state, **kwargs):
|
||||||
is_instruct = kwargs['is_instruct'] if 'is_instruct' in kwargs else False
|
|
||||||
end_of_turn = kwargs['end_of_turn'] if 'end_of_turn' in kwargs else ''
|
|
||||||
impersonate = kwargs['impersonate'] if 'impersonate' in kwargs else False
|
impersonate = kwargs['impersonate'] if 'impersonate' in kwargs else False
|
||||||
|
_continue = kwargs['_continue'] if '_continue' in kwargs else False
|
||||||
also_return_rows = kwargs['also_return_rows'] if 'also_return_rows' in kwargs else False
|
also_return_rows = kwargs['also_return_rows'] if 'also_return_rows' in kwargs else False
|
||||||
|
is_instruct = state['mode'] == 'instruct'
|
||||||
user_input = fix_newlines(user_input)
|
rows = [f"{state['context'].strip()}\n"]
|
||||||
rows = [f"{context.strip()}\n"]
|
|
||||||
|
|
||||||
# Finding the maximum prompt size
|
# Finding the maximum prompt size
|
||||||
|
chat_prompt_size = state['chat_prompt_size']
|
||||||
if shared.soft_prompt:
|
if shared.soft_prompt:
|
||||||
chat_prompt_size -= shared.soft_prompt_tensor.shape[1]
|
chat_prompt_size -= shared.soft_prompt_tensor.shape[1]
|
||||||
max_length = min(get_max_prompt_length(max_new_tokens), chat_prompt_size)
|
max_length = min(get_max_prompt_length(state), chat_prompt_size)
|
||||||
|
|
||||||
if is_instruct:
|
if is_instruct:
|
||||||
prefix1 = f"{name1}\n"
|
prefix1 = f"{state['name1']}\n"
|
||||||
prefix2 = f"{name2}\n"
|
prefix2 = f"{state['name2']}\n"
|
||||||
else:
|
else:
|
||||||
prefix1 = f"{name1}: "
|
prefix1 = f"{state['name1']}: "
|
||||||
prefix2 = f"{name2}: "
|
prefix2 = f"{state['name2']}: "
|
||||||
|
|
||||||
i = len(shared.history['internal']) - 1
|
i = len(shared.history['internal']) - 1
|
||||||
while i >= 0 and len(encode(''.join(rows), max_new_tokens)[0]) < max_length:
|
while i >= 0 and len(encode(''.join(rows))[0]) < max_length:
|
||||||
rows.insert(1, f"{prefix2}{shared.history['internal'][i][1].strip()}{end_of_turn}\n")
|
if _continue and i == len(shared.history['internal']) - 1:
|
||||||
|
rows.insert(1, f"{prefix2}{shared.history['internal'][i][1]}")
|
||||||
|
else:
|
||||||
|
rows.insert(1, f"{prefix2}{shared.history['internal'][i][1].strip()}{state['end_of_turn']}\n")
|
||||||
string = shared.history['internal'][i][0]
|
string = shared.history['internal'][i][0]
|
||||||
if string not in ['', '<|BEGIN-VISIBLE-CHAT|>']:
|
if string not in ['', '<|BEGIN-VISIBLE-CHAT|>']:
|
||||||
rows.insert(1, f"{prefix1}{string.strip()}{end_of_turn}\n")
|
rows.insert(1, f"{prefix1}{string.strip()}{state['end_of_turn']}\n")
|
||||||
i -= 1
|
i -= 1
|
||||||
|
|
||||||
if impersonate:
|
if impersonate:
|
||||||
rows.append(f"{prefix1.strip() if not is_instruct else prefix1}")
|
rows.append(f"{prefix1.strip() if not is_instruct else prefix1}")
|
||||||
limit = 2
|
limit = 2
|
||||||
|
elif _continue:
|
||||||
|
limit = 3
|
||||||
else:
|
else:
|
||||||
|
|
||||||
# Adding the user message
|
# Adding the user message
|
||||||
|
user_input = fix_newlines(user_input)
|
||||||
if len(user_input) > 0:
|
if len(user_input) > 0:
|
||||||
rows.append(f"{prefix1}{user_input}{end_of_turn}\n")
|
rows.append(f"{prefix1}{user_input}{state['end_of_turn']}\n")
|
||||||
|
|
||||||
# Adding the Character prefix
|
# Adding the Character prefix
|
||||||
rows.append(apply_extensions(f"{prefix2.strip() if not is_instruct else prefix2}", "bot_prefix"))
|
rows.append(apply_extensions(f"{prefix2.strip() if not is_instruct else prefix2}", "bot_prefix"))
|
||||||
limit = 3
|
limit = 3
|
||||||
|
|
||||||
while len(rows) > limit and len(encode(''.join(rows), max_new_tokens)[0]) >= max_length:
|
while len(rows) > limit and len(encode(''.join(rows))[0]) >= max_length:
|
||||||
rows.pop(1)
|
rows.pop(1)
|
||||||
prompt = ''.join(rows)
|
prompt = ''.join(rows)
|
||||||
|
|
||||||
@@ -69,16 +73,27 @@ def generate_chat_prompt(user_input, max_new_tokens, name1, name2, context, chat
|
|||||||
else:
|
else:
|
||||||
return prompt
|
return prompt
|
||||||
|
|
||||||
def extract_message_from_reply(reply, name1, name2, stop_at_newline):
|
|
||||||
next_character_found = False
|
|
||||||
|
|
||||||
if stop_at_newline:
|
def get_stopping_strings(state):
|
||||||
|
if state['mode'] == 'instruct':
|
||||||
|
stopping_strings = [f"\n{state['name1']}", f"\n{state['name2']}"]
|
||||||
|
else:
|
||||||
|
stopping_strings = [f"\n{state['name1']}:", f"\n{state['name2']}:"]
|
||||||
|
stopping_strings += state['custom_stopping_strings']
|
||||||
|
return stopping_strings
|
||||||
|
|
||||||
|
|
||||||
|
def extract_message_from_reply(reply, state):
|
||||||
|
next_character_found = False
|
||||||
|
stopping_strings = get_stopping_strings(state)
|
||||||
|
|
||||||
|
if state['stop_at_newline']:
|
||||||
lines = reply.split('\n')
|
lines = reply.split('\n')
|
||||||
reply = lines[0].strip()
|
reply = lines[0].strip()
|
||||||
if len(lines) > 1:
|
if len(lines) > 1:
|
||||||
next_character_found = True
|
next_character_found = True
|
||||||
else:
|
else:
|
||||||
for string in [f"\n{name1}:", f"\n{name2}:"]:
|
for string in stopping_strings:
|
||||||
idx = reply.find(string)
|
idx = reply.find(string)
|
||||||
if idx != -1:
|
if idx != -1:
|
||||||
reply = reply[:idx]
|
reply = reply[:idx]
|
||||||
@@ -87,27 +102,32 @@ def extract_message_from_reply(reply, name1, name2, stop_at_newline):
|
|||||||
# If something like "\nYo" is generated just before "\nYou:"
|
# If something like "\nYo" is generated just before "\nYou:"
|
||||||
# is completed, trim it
|
# is completed, trim it
|
||||||
if not next_character_found:
|
if not next_character_found:
|
||||||
for string in [f"\n{name1}:", f"\n{name2}:"]:
|
for string in stopping_strings:
|
||||||
for j in range(len(string) - 1, 0, -1):
|
for j in range(len(string) - 1, 0, -1):
|
||||||
if reply[-j:] == string[:j]:
|
if reply[-j:] == string[:j]:
|
||||||
reply = reply[:-j]
|
reply = reply[:-j]
|
||||||
break
|
break
|
||||||
|
else:
|
||||||
|
continue
|
||||||
|
break
|
||||||
|
|
||||||
reply = fix_newlines(reply)
|
reply = fix_newlines(reply)
|
||||||
return reply, next_character_found
|
return reply, next_character_found
|
||||||
|
|
||||||
def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn, regenerate=False):
|
|
||||||
|
def chatbot_wrapper(text, state, regenerate=False, _continue=False):
|
||||||
|
|
||||||
|
# Defining some variables
|
||||||
|
cumulative_reply = ''
|
||||||
|
last_reply = [shared.history['internal'][-1][1], shared.history['visible'][-1][1]] if _continue else None
|
||||||
just_started = True
|
just_started = True
|
||||||
eos_token = '\n' if generate_state['stop_at_newline'] else None
|
visible_text = custom_generate_chat_prompt = None
|
||||||
name1_original = name1
|
eos_token = '\n' if state['stop_at_newline'] else None
|
||||||
if 'pygmalion' in shared.model_name.lower():
|
stopping_strings = get_stopping_strings(state)
|
||||||
name1 = "You"
|
|
||||||
|
|
||||||
# Check if any extension wants to hijack this function call
|
# Check if any extension wants to hijack this function call
|
||||||
visible_text = None
|
|
||||||
custom_generate_chat_prompt = None
|
|
||||||
for extension, _ in extensions_module.iterator():
|
for extension, _ in extensions_module.iterator():
|
||||||
if hasattr(extension, 'input_hijack') and extension.input_hijack['state'] == True:
|
if hasattr(extension, 'input_hijack') and extension.input_hijack['state']:
|
||||||
extension.input_hijack['state'] = False
|
extension.input_hijack['state'] = False
|
||||||
text, visible_text = extension.input_hijack['value']
|
text, visible_text = extension.input_hijack['value']
|
||||||
if custom_generate_chat_prompt is None and hasattr(extension, 'custom_generate_chat_prompt'):
|
if custom_generate_chat_prompt is None and hasattr(extension, 'custom_generate_chat_prompt'):
|
||||||
@@ -115,28 +135,29 @@ def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_tu
|
|||||||
|
|
||||||
if visible_text is None:
|
if visible_text is None:
|
||||||
visible_text = text
|
visible_text = text
|
||||||
|
if not _continue:
|
||||||
text = apply_extensions(text, "input")
|
text = apply_extensions(text, "input")
|
||||||
|
|
||||||
kwargs = {'end_of_turn': end_of_turn, 'is_instruct': mode == 'instruct'}
|
# Generating the prompt
|
||||||
|
kwargs = {'_continue': _continue}
|
||||||
if custom_generate_chat_prompt is None:
|
if custom_generate_chat_prompt is None:
|
||||||
prompt = generate_chat_prompt(text, generate_state['max_new_tokens'], name1, name2, context, generate_state['chat_prompt_size'], **kwargs)
|
prompt = generate_chat_prompt(text, state, **kwargs)
|
||||||
else:
|
else:
|
||||||
prompt = custom_generate_chat_prompt(text, generate_state['max_new_tokens'], name1, name2, context, generate_state['chat_prompt_size'], **kwargs)
|
prompt = custom_generate_chat_prompt(text, state, **kwargs)
|
||||||
|
|
||||||
# Yield *Is typing...*
|
# Yield *Is typing...*
|
||||||
if not regenerate:
|
if not any((regenerate, _continue)):
|
||||||
yield shared.history['visible'] + [[visible_text, shared.processing_message]]
|
yield shared.history['visible'] + [[visible_text, shared.processing_message]]
|
||||||
|
|
||||||
# Generate
|
# Generate
|
||||||
cumulative_reply = ''
|
for i in range(state['chat_generation_attempts']):
|
||||||
for i in range(generate_state['chat_generation_attempts']):
|
|
||||||
reply = None
|
reply = None
|
||||||
for reply in generate_reply(f"{prompt}{' ' if len(cumulative_reply) > 0 else ''}{cumulative_reply}", generate_state, eos_token=eos_token, stopping_strings=[f"\n{name1}:", f"\n{name2}:"]):
|
for reply in generate_reply(f"{prompt}{' ' if len(cumulative_reply) > 0 else ''}{cumulative_reply}", state, eos_token=eos_token, stopping_strings=stopping_strings):
|
||||||
reply = cumulative_reply + reply
|
reply = cumulative_reply + reply
|
||||||
|
|
||||||
# Extracting the reply
|
# Extracting the reply
|
||||||
reply, next_character_found = extract_message_from_reply(reply, name1, name2, generate_state['stop_at_newline'])
|
reply, next_character_found = extract_message_from_reply(reply, state)
|
||||||
visible_reply = re.sub("(<USER>|<user>|{{user}})", name1_original, reply)
|
visible_reply = re.sub("(<USER>|<user>|{{user}})", state['name1'], reply)
|
||||||
visible_reply = apply_extensions(visible_reply, "output")
|
visible_reply = apply_extensions(visible_reply, "output")
|
||||||
|
|
||||||
# We need this global variable to handle the Stop event,
|
# We need this global variable to handle the Stop event,
|
||||||
@@ -145,9 +166,15 @@ def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_tu
|
|||||||
return shared.history['visible']
|
return shared.history['visible']
|
||||||
if just_started:
|
if just_started:
|
||||||
just_started = False
|
just_started = False
|
||||||
|
if not _continue:
|
||||||
shared.history['internal'].append(['', ''])
|
shared.history['internal'].append(['', ''])
|
||||||
shared.history['visible'].append(['', ''])
|
shared.history['visible'].append(['', ''])
|
||||||
|
|
||||||
|
if _continue:
|
||||||
|
sep = list(map(lambda x: ' ' if len(x) > 0 and x[-1] != ' ' else '', last_reply))
|
||||||
|
shared.history['internal'][-1] = [text, f'{last_reply[0]}{sep[0]}{reply}']
|
||||||
|
shared.history['visible'][-1] = [visible_text, f'{last_reply[1]}{sep[1]}{visible_reply}']
|
||||||
|
else:
|
||||||
shared.history['internal'][-1] = [text, reply]
|
shared.history['internal'][-1] = [text, reply]
|
||||||
shared.history['visible'][-1] = [visible_text, visible_reply]
|
shared.history['visible'][-1] = [visible_text, visible_reply]
|
||||||
if not shared.args.no_stream:
|
if not shared.args.no_stream:
|
||||||
@@ -160,23 +187,23 @@ def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_tu
|
|||||||
|
|
||||||
yield shared.history['visible']
|
yield shared.history['visible']
|
||||||
|
|
||||||
def impersonate_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn):
|
|
||||||
eos_token = '\n' if generate_state['stop_at_newline'] else None
|
|
||||||
|
|
||||||
if 'pygmalion' in shared.model_name.lower():
|
def impersonate_wrapper(text, state):
|
||||||
name1 = "You"
|
|
||||||
|
|
||||||
prompt = generate_chat_prompt(text, generate_state['max_new_tokens'], name1, name2, context, generate_state['chat_prompt_size'], impersonate=True, end_of_turn=end_of_turn)
|
# Defining some variables
|
||||||
|
cumulative_reply = ''
|
||||||
|
eos_token = '\n' if state['stop_at_newline'] else None
|
||||||
|
prompt = generate_chat_prompt(text, state, impersonate=True)
|
||||||
|
stopping_strings = get_stopping_strings(state)
|
||||||
|
|
||||||
# Yield *Is typing...*
|
# Yield *Is typing...*
|
||||||
yield shared.processing_message
|
yield shared.processing_message
|
||||||
|
|
||||||
cumulative_reply = ''
|
for i in range(state['chat_generation_attempts']):
|
||||||
for i in range(generate_state['chat_generation_attempts']):
|
|
||||||
reply = None
|
reply = None
|
||||||
for reply in generate_reply(f"{prompt}{' ' if len(cumulative_reply) > 0 else ''}{cumulative_reply}", generate_state, eos_token=eos_token, stopping_strings=[f"\n{name1}:", f"\n{name2}:"]):
|
for reply in generate_reply(f"{prompt}{' ' if len(cumulative_reply) > 0 else ''}{cumulative_reply}", state, eos_token=eos_token, stopping_strings=stopping_strings):
|
||||||
reply = cumulative_reply + reply
|
reply = cumulative_reply + reply
|
||||||
reply, next_character_found = extract_message_from_reply(reply, name1, name2, generate_state['stop_at_newline'])
|
reply, next_character_found = extract_message_from_reply(reply, state)
|
||||||
yield reply
|
yield reply
|
||||||
if next_character_found:
|
if next_character_found:
|
||||||
break
|
break
|
||||||
@@ -186,21 +213,34 @@ def impersonate_wrapper(text, generate_state, name1, name2, context, mode, end_o
|
|||||||
|
|
||||||
yield reply
|
yield reply
|
||||||
|
|
||||||
def cai_chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn):
|
|
||||||
for history in chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn, regenerate=False):
|
|
||||||
yield chat_html_wrapper(history, name1, name2, mode)
|
|
||||||
|
|
||||||
def regenerate_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn):
|
def cai_chatbot_wrapper(text, state):
|
||||||
if (shared.character != 'None' and len(shared.history['visible']) == 1) or len(shared.history['internal']) == 0:
|
for history in chatbot_wrapper(text, state):
|
||||||
yield chat_html_wrapper(shared.history['visible'], name1, name2, mode)
|
yield chat_html_wrapper(history, state['name1'], state['name2'], state['mode'])
|
||||||
|
|
||||||
|
|
||||||
|
def regenerate_wrapper(text, state):
|
||||||
|
if (len(shared.history['visible']) == 1 and not shared.history['visible'][0][0]) or len(shared.history['internal']) == 0:
|
||||||
|
yield chat_html_wrapper(shared.history['visible'], state['name1'], state['name2'], state['mode'])
|
||||||
else:
|
else:
|
||||||
last_visible = shared.history['visible'].pop()
|
last_visible = shared.history['visible'].pop()
|
||||||
last_internal = shared.history['internal'].pop()
|
last_internal = shared.history['internal'].pop()
|
||||||
# Yield '*Is typing...*'
|
# Yield '*Is typing...*'
|
||||||
yield chat_html_wrapper(shared.history['visible']+[[last_visible[0], shared.processing_message]], name1, name2, mode)
|
yield chat_html_wrapper(shared.history['visible'] + [[last_visible[0], shared.processing_message]], state['name1'], state['name2'], state['mode'])
|
||||||
for history in chatbot_wrapper(last_internal[0], generate_state, name1, name2, context, mode, end_of_turn, regenerate=True):
|
for history in chatbot_wrapper(last_internal[0], state, regenerate=True):
|
||||||
shared.history['visible'][-1] = [last_visible[0], history[-1][1]]
|
shared.history['visible'][-1] = [last_visible[0], history[-1][1]]
|
||||||
yield chat_html_wrapper(shared.history['visible'], name1, name2, mode)
|
yield chat_html_wrapper(shared.history['visible'], state['name1'], state['name2'], state['mode'])
|
||||||
|
|
||||||
|
|
||||||
|
def continue_wrapper(text, state):
|
||||||
|
if (len(shared.history['visible']) == 1 and not shared.history['visible'][0][0]) or len(shared.history['internal']) == 0:
|
||||||
|
yield chat_html_wrapper(shared.history['visible'], state['name1'], state['name2'], state['mode'])
|
||||||
|
else:
|
||||||
|
# Yield ' ...'
|
||||||
|
yield chat_html_wrapper(shared.history['visible'][:-1] + [[shared.history['visible'][-1][0], shared.history['visible'][-1][1] + ' ...']], state['name1'], state['name2'], state['mode'])
|
||||||
|
for history in chatbot_wrapper(shared.history['internal'][-1][0], state, _continue=True):
|
||||||
|
yield chat_html_wrapper(shared.history['visible'], state['name1'], state['name2'], state['mode'])
|
||||||
|
|
||||||
|
|
||||||
def remove_last_message(name1, name2, mode):
|
def remove_last_message(name1, name2, mode):
|
||||||
if len(shared.history['visible']) > 0 and shared.history['internal'][-1][0] != '<|BEGIN-VISIBLE-CHAT|>':
|
if len(shared.history['visible']) > 0 and shared.history['internal'][-1][0] != '<|BEGIN-VISIBLE-CHAT|>':
|
||||||
@@ -211,12 +251,14 @@ def remove_last_message(name1, name2, mode):
|
|||||||
|
|
||||||
return chat_html_wrapper(shared.history['visible'], name1, name2, mode), last[0]
|
return chat_html_wrapper(shared.history['visible'], name1, name2, mode), last[0]
|
||||||
|
|
||||||
|
|
||||||
def send_last_reply_to_input():
|
def send_last_reply_to_input():
|
||||||
if len(shared.history['internal']) > 0:
|
if len(shared.history['internal']) > 0:
|
||||||
return shared.history['internal'][-1][1]
|
return shared.history['internal'][-1][1]
|
||||||
else:
|
else:
|
||||||
return ''
|
return ''
|
||||||
|
|
||||||
|
|
||||||
def replace_last_reply(text, name1, name2, mode):
|
def replace_last_reply(text, name1, name2, mode):
|
||||||
if len(shared.history['visible']) > 0:
|
if len(shared.history['visible']) > 0:
|
||||||
shared.history['visible'][-1][1] = text
|
shared.history['visible'][-1][1] = text
|
||||||
@@ -224,9 +266,26 @@ def replace_last_reply(text, name1, name2, mode):
|
|||||||
|
|
||||||
return chat_html_wrapper(shared.history['visible'], name1, name2, mode)
|
return chat_html_wrapper(shared.history['visible'], name1, name2, mode)
|
||||||
|
|
||||||
|
|
||||||
|
def send_dummy_message(text, name1, name2, mode):
|
||||||
|
shared.history['visible'].append([text, ''])
|
||||||
|
shared.history['internal'].append([apply_extensions(text, "input"), ''])
|
||||||
|
return chat_html_wrapper(shared.history['visible'], name1, name2, mode)
|
||||||
|
|
||||||
|
|
||||||
|
def send_dummy_reply(text, name1, name2, mode):
|
||||||
|
if len(shared.history['visible']) > 0 and not shared.history['visible'][-1][1] == '':
|
||||||
|
shared.history['visible'].append(['', ''])
|
||||||
|
shared.history['internal'].append(['', ''])
|
||||||
|
shared.history['visible'][-1][1] = text
|
||||||
|
shared.history['internal'][-1][1] = apply_extensions(text, "input")
|
||||||
|
return chat_html_wrapper(shared.history['visible'], name1, name2, mode)
|
||||||
|
|
||||||
|
|
||||||
def clear_html():
|
def clear_html():
|
||||||
return chat_html_wrapper([], "", "")
|
return chat_html_wrapper([], "", "")
|
||||||
|
|
||||||
|
|
||||||
def clear_chat_log(name1, name2, greeting, mode):
|
def clear_chat_log(name1, name2, greeting, mode):
|
||||||
shared.history['visible'] = []
|
shared.history['visible'] = []
|
||||||
shared.history['internal'] = []
|
shared.history['internal'] = []
|
||||||
@@ -235,14 +294,19 @@ def clear_chat_log(name1, name2, greeting, mode):
|
|||||||
shared.history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', greeting]]
|
shared.history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', greeting]]
|
||||||
shared.history['visible'] += [['', apply_extensions(greeting, "output")]]
|
shared.history['visible'] += [['', apply_extensions(greeting, "output")]]
|
||||||
|
|
||||||
|
# Save cleared logs
|
||||||
|
save_history(mode)
|
||||||
|
|
||||||
return chat_html_wrapper(shared.history['visible'], name1, name2, mode)
|
return chat_html_wrapper(shared.history['visible'], name1, name2, mode)
|
||||||
|
|
||||||
|
|
||||||
def redraw_html(name1, name2, mode):
|
def redraw_html(name1, name2, mode):
|
||||||
return chat_html_wrapper(shared.history['visible'], name1, name2, mode)
|
return chat_html_wrapper(shared.history['visible'], name1, name2, mode)
|
||||||
|
|
||||||
|
|
||||||
def tokenize_dialogue(dialogue, name1, name2, mode):
|
def tokenize_dialogue(dialogue, name1, name2, mode):
|
||||||
history = []
|
history = []
|
||||||
|
messages = []
|
||||||
dialogue = re.sub('<START>', '', dialogue)
|
dialogue = re.sub('<START>', '', dialogue)
|
||||||
dialogue = re.sub('<start>', '', dialogue)
|
dialogue = re.sub('<start>', '', dialogue)
|
||||||
dialogue = re.sub('(\n|^)[Aa]non:', '\\1You:', dialogue)
|
dialogue = re.sub('(\n|^)[Aa]non:', '\\1You:', dialogue)
|
||||||
@@ -251,7 +315,6 @@ def tokenize_dialogue(dialogue, name1, name2, mode):
|
|||||||
if len(idx) == 0:
|
if len(idx) == 0:
|
||||||
return history
|
return history
|
||||||
|
|
||||||
messages = []
|
|
||||||
for i in range(len(idx) - 1):
|
for i in range(len(idx) - 1):
|
||||||
messages.append(dialogue[idx[i]:idx[i + 1]].strip())
|
messages.append(dialogue[idx[i]:idx[i + 1]].strip())
|
||||||
messages.append(dialogue[idx[-1]:].strip())
|
messages.append(dialogue[idx[-1]:].strip())
|
||||||
@@ -277,7 +340,15 @@ def tokenize_dialogue(dialogue, name1, name2, mode):
|
|||||||
|
|
||||||
return history
|
return history
|
||||||
|
|
||||||
def save_history(timestamp=True):
|
|
||||||
|
def save_history(mode, timestamp=False):
|
||||||
|
# Instruct mode histories should not be saved as if
|
||||||
|
# Alpaca or Vicuna were characters
|
||||||
|
if mode == 'instruct':
|
||||||
|
if not timestamp:
|
||||||
|
return
|
||||||
|
fname = f"Instruct_{datetime.now().strftime('%Y%m%d-%H%M%S')}.json"
|
||||||
|
else:
|
||||||
if timestamp:
|
if timestamp:
|
||||||
fname = f"{shared.character}_{datetime.now().strftime('%Y%m%d-%H%M%S')}.json"
|
fname = f"{shared.character}_{datetime.now().strftime('%Y%m%d-%H%M%S')}.json"
|
||||||
else:
|
else:
|
||||||
@@ -286,8 +357,10 @@ def save_history(timestamp=True):
|
|||||||
Path('logs').mkdir()
|
Path('logs').mkdir()
|
||||||
with open(Path(f'logs/{fname}'), 'w', encoding='utf-8') as f:
|
with open(Path(f'logs/{fname}'), 'w', encoding='utf-8') as f:
|
||||||
f.write(json.dumps({'data': shared.history['internal'], 'data_visible': shared.history['visible']}, indent=2))
|
f.write(json.dumps({'data': shared.history['internal'], 'data_visible': shared.history['visible']}, indent=2))
|
||||||
|
|
||||||
return Path(f'logs/{fname}')
|
return Path(f'logs/{fname}')
|
||||||
|
|
||||||
|
|
||||||
def load_history(file, name1, name2):
|
def load_history(file, name1, name2):
|
||||||
file = file.decode('utf-8')
|
file = file.decode('utf-8')
|
||||||
try:
|
try:
|
||||||
@@ -298,24 +371,16 @@ def load_history(file, name1, name2):
|
|||||||
shared.history['visible'] = j['data_visible']
|
shared.history['visible'] = j['data_visible']
|
||||||
else:
|
else:
|
||||||
shared.history['visible'] = copy.deepcopy(shared.history['internal'])
|
shared.history['visible'] = copy.deepcopy(shared.history['internal'])
|
||||||
# Compatibility with Pygmalion AI's official web UI
|
|
||||||
elif 'chat' in j:
|
|
||||||
shared.history['internal'] = [':'.join(x.split(':')[1:]).strip() for x in j['chat']]
|
|
||||||
if len(j['chat']) > 0 and j['chat'][0].startswith(f'{name2}:'):
|
|
||||||
shared.history['internal'] = [['<|BEGIN-VISIBLE-CHAT|>', shared.history['internal'][0]]] + [[shared.history['internal'][i], shared.history['internal'][i+1]] for i in range(1, len(shared.history['internal'])-1, 2)]
|
|
||||||
shared.history['visible'] = copy.deepcopy(shared.history['internal'])
|
|
||||||
shared.history['visible'][0][0] = ''
|
|
||||||
else:
|
|
||||||
shared.history['internal'] = [[shared.history['internal'][i], shared.history['internal'][i+1]] for i in range(0, len(shared.history['internal'])-1, 2)]
|
|
||||||
shared.history['visible'] = copy.deepcopy(shared.history['internal'])
|
|
||||||
except:
|
except:
|
||||||
shared.history['internal'] = tokenize_dialogue(file, name1, name2)
|
shared.history['internal'] = tokenize_dialogue(file, name1, name2)
|
||||||
shared.history['visible'] = copy.deepcopy(shared.history['internal'])
|
shared.history['visible'] = copy.deepcopy(shared.history['internal'])
|
||||||
|
|
||||||
|
|
||||||
def replace_character_names(text, name1, name2):
|
def replace_character_names(text, name1, name2):
|
||||||
text = text.replace('{{user}}', name1).replace('{{char}}', name2)
|
text = text.replace('{{user}}', name1).replace('{{char}}', name2)
|
||||||
return text.replace('<USER>', name1).replace('<BOT>', name2)
|
return text.replace('<USER>', name1).replace('<BOT>', name2)
|
||||||
|
|
||||||
|
|
||||||
def build_pygmalion_style_context(data):
|
def build_pygmalion_style_context(data):
|
||||||
context = ""
|
context = ""
|
||||||
if 'char_persona' in data and data['char_persona'] != '':
|
if 'char_persona' in data and data['char_persona'] != '':
|
||||||
@@ -325,6 +390,7 @@ def build_pygmalion_style_context(data):
|
|||||||
context = f"{context.strip()}\n<START>\n"
|
context = f"{context.strip()}\n<START>\n"
|
||||||
return context
|
return context
|
||||||
|
|
||||||
|
|
||||||
def generate_pfp_cache(character):
|
def generate_pfp_cache(character):
|
||||||
cache_folder = Path("cache")
|
cache_folder = Path("cache")
|
||||||
if not cache_folder.exists():
|
if not cache_folder.exists():
|
||||||
@@ -337,10 +403,9 @@ def generate_pfp_cache(character):
|
|||||||
return img
|
return img
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def load_character(character, name1, name2, mode):
|
def load_character(character, name1, name2, mode):
|
||||||
shared.character = character
|
shared.character = character
|
||||||
shared.history['internal'] = []
|
|
||||||
shared.history['visible'] = []
|
|
||||||
context = greeting = end_of_turn = ""
|
context = greeting = end_of_turn = ""
|
||||||
greeting_field = 'greeting'
|
greeting_field = 'greeting'
|
||||||
picture = None
|
picture = None
|
||||||
@@ -385,17 +450,28 @@ def load_character(character, name1, name2, mode):
|
|||||||
greeting = shared.settings['greeting']
|
greeting = shared.settings['greeting']
|
||||||
end_of_turn = shared.settings['end_of_turn']
|
end_of_turn = shared.settings['end_of_turn']
|
||||||
|
|
||||||
|
if mode != 'instruct':
|
||||||
|
shared.history['internal'] = []
|
||||||
|
shared.history['visible'] = []
|
||||||
|
|
||||||
if Path(f'logs/{shared.character}_persistent.json').exists():
|
if Path(f'logs/{shared.character}_persistent.json').exists():
|
||||||
load_history(open(Path(f'logs/{shared.character}_persistent.json'), 'rb').read(), name1, name2)
|
load_history(open(Path(f'logs/{shared.character}_persistent.json'), 'rb').read(), name1, name2)
|
||||||
elif greeting != "":
|
else:
|
||||||
|
# Insert greeting if it exists
|
||||||
|
if greeting != "":
|
||||||
shared.history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', greeting]]
|
shared.history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', greeting]]
|
||||||
shared.history['visible'] += [['', apply_extensions(greeting, "output")]]
|
shared.history['visible'] += [['', apply_extensions(greeting, "output")]]
|
||||||
|
|
||||||
return name1, name2, picture, greeting, context, end_of_turn, chat_html_wrapper(shared.history['visible'], name1, name2, mode, reset_cache=True)
|
# Create .json log files since they don't already exist
|
||||||
|
save_history(mode)
|
||||||
|
|
||||||
|
return name1, name2, picture, greeting, context, end_of_turn, chat_html_wrapper(shared.history['visible'], name1, name2, mode)
|
||||||
|
|
||||||
|
|
||||||
def load_default_history(name1, name2):
|
def load_default_history(name1, name2):
|
||||||
load_character("None", name1, name2, "chat")
|
load_character("None", name1, name2, "chat")
|
||||||
|
|
||||||
|
|
||||||
def upload_character(json_file, img, tavern=False):
|
def upload_character(json_file, img, tavern=False):
|
||||||
json_file = json_file if type(json_file) == str else json_file.decode('utf-8')
|
json_file = json_file if type(json_file) == str else json_file.decode('utf-8')
|
||||||
data = json.loads(json_file)
|
data = json.loads(json_file)
|
||||||
@@ -414,6 +490,7 @@ def upload_character(json_file, img, tavern=False):
|
|||||||
print(f'New character saved to "characters/{outfile_name}.json".')
|
print(f'New character saved to "characters/{outfile_name}.json".')
|
||||||
return outfile_name
|
return outfile_name
|
||||||
|
|
||||||
|
|
||||||
def upload_tavern_character(img, name1, name2):
|
def upload_tavern_character(img, name1, name2):
|
||||||
_img = Image.open(io.BytesIO(img))
|
_img = Image.open(io.BytesIO(img))
|
||||||
_img.getexif()
|
_img.getexif()
|
||||||
@@ -422,12 +499,13 @@ def upload_tavern_character(img, name1, name2):
|
|||||||
_json = {"char_name": _json['name'], "char_persona": _json['description'], "char_greeting": _json["first_mes"], "example_dialogue": _json['mes_example'], "world_scenario": _json['scenario']}
|
_json = {"char_name": _json['name'], "char_persona": _json['description'], "char_greeting": _json["first_mes"], "example_dialogue": _json['mes_example'], "world_scenario": _json['scenario']}
|
||||||
return upload_character(json.dumps(_json), img, tavern=True)
|
return upload_character(json.dumps(_json), img, tavern=True)
|
||||||
|
|
||||||
|
|
||||||
def upload_your_profile_picture(img, name1, name2, mode):
|
def upload_your_profile_picture(img, name1, name2, mode):
|
||||||
cache_folder = Path("cache")
|
cache_folder = Path("cache")
|
||||||
if not cache_folder.exists():
|
if not cache_folder.exists():
|
||||||
cache_folder.mkdir()
|
cache_folder.mkdir()
|
||||||
|
|
||||||
if img == None:
|
if img is None:
|
||||||
if Path("cache/pfp_me.png").exists():
|
if Path("cache/pfp_me.png").exists():
|
||||||
Path("cache/pfp_me.png").unlink()
|
Path("cache/pfp_me.png").unlink()
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -9,25 +9,32 @@ state = {}
|
|||||||
available_extensions = []
|
available_extensions = []
|
||||||
setup_called = set()
|
setup_called = set()
|
||||||
|
|
||||||
|
|
||||||
def load_extensions():
|
def load_extensions():
|
||||||
global state
|
global state, setup_called
|
||||||
for i, name in enumerate(shared.args.extensions):
|
for i, name in enumerate(shared.args.extensions):
|
||||||
if name in available_extensions:
|
if name in available_extensions:
|
||||||
print(f'Loading the extension "{name}"... ', end='')
|
print(f'Loading the extension "{name}"... ', end='')
|
||||||
try:
|
try:
|
||||||
exec(f"import extensions.{name}.script")
|
exec(f"import extensions.{name}.script")
|
||||||
|
extension = eval(f"extensions.{name}.script")
|
||||||
|
if extension not in setup_called and hasattr(extension, "setup"):
|
||||||
|
setup_called.add(extension)
|
||||||
|
extension.setup()
|
||||||
state[name] = [True, i]
|
state[name] = [True, i]
|
||||||
print('Ok.')
|
print('Ok.')
|
||||||
except:
|
except:
|
||||||
print('Fail.')
|
print('Fail.')
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
|
|
||||||
|
|
||||||
# This iterator returns the extensions in the order specified in the command-line
|
# This iterator returns the extensions in the order specified in the command-line
|
||||||
def iterator():
|
def iterator():
|
||||||
for name in sorted(state, key=lambda x: state[x][1]):
|
for name in sorted(state, key=lambda x: state[x][1]):
|
||||||
if state[name][0] == True:
|
if state[name][0]:
|
||||||
yield eval(f"extensions.{name}.script"), name
|
yield eval(f"extensions.{name}.script"), name
|
||||||
|
|
||||||
|
|
||||||
# Extension functions that map string -> string
|
# Extension functions that map string -> string
|
||||||
def apply_extensions(text, typ):
|
def apply_extensions(text, typ):
|
||||||
for extension, _ in iterator():
|
for extension, _ in iterator():
|
||||||
@@ -39,6 +46,7 @@ def apply_extensions(text, typ):
|
|||||||
text = extension.bot_prefix_modifier(text)
|
text = extension.bot_prefix_modifier(text)
|
||||||
return text
|
return text
|
||||||
|
|
||||||
|
|
||||||
def create_extensions_block():
|
def create_extensions_block():
|
||||||
global setup_called
|
global setup_called
|
||||||
|
|
||||||
@@ -51,14 +59,9 @@ def create_extensions_block():
|
|||||||
extension.params[param] = shared.settings[_id]
|
extension.params[param] = shared.settings[_id]
|
||||||
|
|
||||||
should_display_ui = False
|
should_display_ui = False
|
||||||
|
|
||||||
# Running setup function
|
|
||||||
for extension, name in iterator():
|
for extension, name in iterator():
|
||||||
if hasattr(extension, "ui"):
|
if hasattr(extension, "ui"):
|
||||||
should_display_ui = True
|
should_display_ui = True
|
||||||
if extension not in setup_called and hasattr(extension, "setup"):
|
|
||||||
setup_called.add(extension)
|
|
||||||
extension.setup()
|
|
||||||
|
|
||||||
# Creating the extension ui elements
|
# Creating the extension ui elements
|
||||||
if should_display_ui:
|
if should_display_ui:
|
||||||
|
|||||||
@@ -24,6 +24,7 @@ with open(Path(__file__).resolve().parent / '../css/html_cai_style.css', 'r') as
|
|||||||
with open(Path(__file__).resolve().parent / '../css/html_instruct_style.css', 'r') as f:
|
with open(Path(__file__).resolve().parent / '../css/html_instruct_style.css', 'r') as f:
|
||||||
instruct_css = f.read()
|
instruct_css = f.read()
|
||||||
|
|
||||||
|
|
||||||
def fix_newlines(string):
|
def fix_newlines(string):
|
||||||
string = string.replace('\n', '\n\n')
|
string = string.replace('\n', '\n\n')
|
||||||
string = re.sub(r"\n{3,}", "\n\n", string)
|
string = re.sub(r"\n{3,}", "\n\n", string)
|
||||||
@@ -31,6 +32,8 @@ def fix_newlines(string):
|
|||||||
return string
|
return string
|
||||||
|
|
||||||
# This could probably be generalized and improved
|
# This could probably be generalized and improved
|
||||||
|
|
||||||
|
|
||||||
def convert_to_markdown(string):
|
def convert_to_markdown(string):
|
||||||
string = string.replace('\\begin{code}', '```')
|
string = string.replace('\\begin{code}', '```')
|
||||||
string = string.replace('\\end{code}', '```')
|
string = string.replace('\\end{code}', '```')
|
||||||
@@ -40,11 +43,13 @@ def convert_to_markdown(string):
|
|||||||
string = fix_newlines(string)
|
string = fix_newlines(string)
|
||||||
return markdown.markdown(string, extensions=['fenced_code'])
|
return markdown.markdown(string, extensions=['fenced_code'])
|
||||||
|
|
||||||
|
|
||||||
def generate_basic_html(string):
|
def generate_basic_html(string):
|
||||||
string = convert_to_markdown(string)
|
string = convert_to_markdown(string)
|
||||||
string = f'<style>{readable_css}</style><div class="container">{string}</div>'
|
string = f'<style>{readable_css}</style><div class="container">{string}</div>'
|
||||||
return string
|
return string
|
||||||
|
|
||||||
|
|
||||||
def process_post(post, c):
|
def process_post(post, c):
|
||||||
t = post.split('\n')
|
t = post.split('\n')
|
||||||
number = t[0].split(' ')[1]
|
number = t[0].split(' ')[1]
|
||||||
@@ -59,6 +64,7 @@ def process_post(post, c):
|
|||||||
src = f'<span class="name">Anonymous </span> <span class="number">No.{number}</span>\n{src}'
|
src = f'<span class="name">Anonymous </span> <span class="number">No.{number}</span>\n{src}'
|
||||||
return src
|
return src
|
||||||
|
|
||||||
|
|
||||||
def generate_4chan_html(f):
|
def generate_4chan_html(f):
|
||||||
posts = []
|
posts = []
|
||||||
post = ''
|
post = ''
|
||||||
@@ -98,6 +104,7 @@ def generate_4chan_html(f):
|
|||||||
|
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
def make_thumbnail(image):
|
def make_thumbnail(image):
|
||||||
image = image.resize((350, round(image.size[1] / image.size[0] * 350)), Image.Resampling.LANCZOS)
|
image = image.resize((350, round(image.size[1] / image.size[0] * 350)), Image.Resampling.LANCZOS)
|
||||||
if image.size[1] > 470:
|
if image.size[1] > 470:
|
||||||
@@ -105,6 +112,7 @@ def make_thumbnail(image):
|
|||||||
|
|
||||||
return image
|
return image
|
||||||
|
|
||||||
|
|
||||||
def get_image_cache(path):
|
def get_image_cache(path):
|
||||||
cache_folder = Path("cache")
|
cache_folder = Path("cache")
|
||||||
if not cache_folder.exists():
|
if not cache_folder.exists():
|
||||||
@@ -119,6 +127,7 @@ def get_image_cache(path):
|
|||||||
|
|
||||||
return image_cache[path][1]
|
return image_cache[path][1]
|
||||||
|
|
||||||
|
|
||||||
def generate_instruct_html(history):
|
def generate_instruct_html(history):
|
||||||
output = f'<style>{instruct_css}</style><div class="chat" id="chat">'
|
output = f'<style>{instruct_css}</style><div class="chat" id="chat">'
|
||||||
for i, _row in enumerate(history[::-1]):
|
for i, _row in enumerate(history[::-1]):
|
||||||
@@ -151,13 +160,13 @@ def generate_instruct_html(history):
|
|||||||
|
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
def generate_cai_chat_html(history, name1, name2, reset_cache=False):
|
def generate_cai_chat_html(history, name1, name2, reset_cache=False):
|
||||||
output = f'<style>{cai_css}</style><div class="chat" id="chat">'
|
output = f'<style>{cai_css}</style><div class="chat" id="chat">'
|
||||||
|
|
||||||
# The time.time() is to prevent the brower from caching the image
|
# We use ?name2 and ?time.time() to force the browser to reset caches
|
||||||
suffix = f"?{time.time()}" if reset_cache else f"?{name2}"
|
img_bot = f'<img src="file/cache/pfp_character.png?{name2}">' if Path("cache/pfp_character.png").exists() else ''
|
||||||
img_bot = f'<img src="file/cache/pfp_character.png{suffix}">' if Path("cache/pfp_character.png").exists() else ''
|
img_me = f'<img src="file/cache/pfp_me.png?{time.time() if reset_cache else ""}">' if Path("cache/pfp_me.png").exists() else ''
|
||||||
img_me = f'<img src="file/cache/pfp_me.png{suffix}">' if Path("cache/pfp_me.png").exists() else ''
|
|
||||||
|
|
||||||
for i, _row in enumerate(history[::-1]):
|
for i, _row in enumerate(history[::-1]):
|
||||||
row = [convert_to_markdown(entry) for entry in _row]
|
row = [convert_to_markdown(entry) for entry in _row]
|
||||||
@@ -200,9 +209,11 @@ def generate_cai_chat_html(history, name1, name2, reset_cache=False):
|
|||||||
output += "</div>"
|
output += "</div>"
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
def generate_chat_html(history, name1, name2):
|
def generate_chat_html(history, name1, name2):
|
||||||
return generate_cai_chat_html(history, name1, name2)
|
return generate_cai_chat_html(history, name1, name2)
|
||||||
|
|
||||||
|
|
||||||
def chat_html_wrapper(history, name1, name2, mode, reset_cache=False):
|
def chat_html_wrapper(history, name1, name2, mode, reset_cache=False):
|
||||||
if mode == "cai-chat":
|
if mode == "cai-chat":
|
||||||
return generate_cai_chat_html(history, name1, name2, reset_cache)
|
return generate_cai_chat_html(history, name1, name2, reset_cache)
|
||||||
|
|||||||
176
modules/llama_attn_hijack.py
Normal file
176
modules/llama_attn_hijack.py
Normal file
@@ -0,0 +1,176 @@
|
|||||||
|
import math
|
||||||
|
import sys
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import transformers.models.llama.modeling_llama
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
|
import modules.shared as shared
|
||||||
|
|
||||||
|
|
||||||
|
if shared.args.xformers:
|
||||||
|
try:
|
||||||
|
import xformers.ops
|
||||||
|
except Exception:
|
||||||
|
print("🔴 xformers not found! Please install it before trying to use it.", file=sys.stderr)
|
||||||
|
|
||||||
|
|
||||||
|
def hijack_llama_attention():
|
||||||
|
if shared.args.xformers:
|
||||||
|
transformers.models.llama.modeling_llama.LlamaAttention.forward = xformers_forward
|
||||||
|
print("Replaced attention with xformers_attention")
|
||||||
|
elif shared.args.sdp_attention:
|
||||||
|
transformers.models.llama.modeling_llama.LlamaAttention.forward = sdp_attention_forward
|
||||||
|
print("Replaced attention with sdp_attention")
|
||||||
|
|
||||||
|
|
||||||
|
def xformers_forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||||
|
output_attentions: bool = False,
|
||||||
|
use_cache: bool = False,
|
||||||
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||||
|
bsz, q_len, _ = hidden_states.size()
|
||||||
|
|
||||||
|
query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||||
|
key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||||
|
value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||||
|
|
||||||
|
kv_seq_len = key_states.shape[-2]
|
||||||
|
if past_key_value is not None:
|
||||||
|
kv_seq_len += past_key_value[0].shape[-2]
|
||||||
|
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||||
|
query_states, key_states = transformers.models.llama.modeling_llama.apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
||||||
|
# [bsz, nh, t, hd]
|
||||||
|
|
||||||
|
if past_key_value is not None:
|
||||||
|
# reuse k, v, self_attention
|
||||||
|
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
||||||
|
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
||||||
|
|
||||||
|
past_key_value = (key_states, value_states) if use_cache else None
|
||||||
|
|
||||||
|
#We only apply xformers optimizations if we don't need to output the whole attention matrix
|
||||||
|
if not output_attentions:
|
||||||
|
dtype = query_states.dtype
|
||||||
|
|
||||||
|
query_states = query_states.transpose(1, 2)
|
||||||
|
key_states = key_states.transpose(1, 2)
|
||||||
|
value_states = value_states.transpose(1, 2)
|
||||||
|
|
||||||
|
#This is a nasty hack. We know attention_mask in transformers is either LowerTriangular or all Zeros.
|
||||||
|
#We therefore check if one element in the upper triangular portion is zero. If it is, then the mask is all zeros.
|
||||||
|
if attention_mask is None or attention_mask[0, 0, 0, 1] == 0:
|
||||||
|
# input and output should be of form (bsz, q_len, num_heads, head_dim)
|
||||||
|
attn_output = xformers.ops.memory_efficient_attention(query_states, key_states, value_states, attn_bias=None)
|
||||||
|
else:
|
||||||
|
# input and output should be of form (bsz, q_len, num_heads, head_dim)
|
||||||
|
attn_output = xformers.ops.memory_efficient_attention(query_states, key_states, value_states, attn_bias=xformers.ops.LowerTriangularMask())
|
||||||
|
attn_weights = None
|
||||||
|
else:
|
||||||
|
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
|
||||||
|
|
||||||
|
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
|
||||||
|
raise ValueError(
|
||||||
|
f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is"
|
||||||
|
f" {attn_weights.size()}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if attention_mask is not None:
|
||||||
|
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
|
||||||
|
raise ValueError(
|
||||||
|
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
|
||||||
|
)
|
||||||
|
attn_weights = attn_weights + attention_mask
|
||||||
|
attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
|
||||||
|
|
||||||
|
# upcast attention to fp32
|
||||||
|
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
|
||||||
|
attn_output = torch.matmul(attn_weights, value_states)
|
||||||
|
|
||||||
|
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
|
||||||
|
raise ValueError(
|
||||||
|
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
|
||||||
|
f" {attn_output.size()}"
|
||||||
|
)
|
||||||
|
|
||||||
|
attn_output = attn_output.transpose(1, 2)
|
||||||
|
|
||||||
|
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
||||||
|
|
||||||
|
attn_output = self.o_proj(attn_output)
|
||||||
|
|
||||||
|
return attn_output, attn_weights, past_key_value
|
||||||
|
|
||||||
|
|
||||||
|
def sdp_attention_forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||||
|
output_attentions: bool = False,
|
||||||
|
use_cache: bool = False,
|
||||||
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||||
|
bsz, q_len, _ = hidden_states.size()
|
||||||
|
|
||||||
|
query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||||
|
key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||||
|
value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||||
|
|
||||||
|
kv_seq_len = key_states.shape[-2]
|
||||||
|
if past_key_value is not None:
|
||||||
|
kv_seq_len += past_key_value[0].shape[-2]
|
||||||
|
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||||
|
query_states, key_states = transformers.models.llama.modeling_llama.apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
||||||
|
# [bsz, nh, t, hd]
|
||||||
|
|
||||||
|
if past_key_value is not None:
|
||||||
|
# reuse k, v, self_attention
|
||||||
|
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
||||||
|
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
||||||
|
|
||||||
|
past_key_value = (key_states, value_states) if use_cache else None
|
||||||
|
|
||||||
|
#We only apply sdp attention if we don't need to output the whole attention matrix
|
||||||
|
if not output_attentions:
|
||||||
|
attn_output = torch.nn.functional.scaled_dot_product_attention(query_states, key_states, value_states, attn_mask=attention_mask, is_causal=False)
|
||||||
|
attn_weights = None
|
||||||
|
else:
|
||||||
|
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
|
||||||
|
|
||||||
|
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
|
||||||
|
raise ValueError(
|
||||||
|
f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is"
|
||||||
|
f" {attn_weights.size()}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if attention_mask is not None:
|
||||||
|
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
|
||||||
|
raise ValueError(
|
||||||
|
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
|
||||||
|
)
|
||||||
|
attn_weights = attn_weights + attention_mask
|
||||||
|
attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
|
||||||
|
|
||||||
|
# upcast attention to fp32
|
||||||
|
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
|
||||||
|
attn_output = torch.matmul(attn_weights, value_states)
|
||||||
|
|
||||||
|
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
|
||||||
|
raise ValueError(
|
||||||
|
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
|
||||||
|
f" {attn_output.size()}"
|
||||||
|
)
|
||||||
|
|
||||||
|
attn_output = attn_output.transpose(1, 2)
|
||||||
|
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
||||||
|
|
||||||
|
attn_output = self.o_proj(attn_output)
|
||||||
|
|
||||||
|
return attn_output, attn_weights, past_key_value
|
||||||
63
modules/llamacpp_model_alternative.py
Normal file
63
modules/llamacpp_model_alternative.py
Normal file
@@ -0,0 +1,63 @@
|
|||||||
|
'''
|
||||||
|
Based on
|
||||||
|
https://github.com/abetlen/llama-cpp-python
|
||||||
|
|
||||||
|
Documentation:
|
||||||
|
https://abetlen.github.io/llama-cpp-python/
|
||||||
|
'''
|
||||||
|
|
||||||
|
from llama_cpp import Llama
|
||||||
|
|
||||||
|
from modules import shared
|
||||||
|
from modules.callbacks import Iteratorize
|
||||||
|
|
||||||
|
|
||||||
|
class LlamaCppModel:
|
||||||
|
def __init__(self):
|
||||||
|
self.initialized = False
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(self, path):
|
||||||
|
result = self()
|
||||||
|
|
||||||
|
params = {
|
||||||
|
'model_path': str(path),
|
||||||
|
'n_ctx': 2048,
|
||||||
|
'seed': 0,
|
||||||
|
'n_threads': shared.args.threads or None
|
||||||
|
}
|
||||||
|
self.model = Llama(**params)
|
||||||
|
|
||||||
|
# This is ugly, but the model and the tokenizer are the same object in this library.
|
||||||
|
return result, result
|
||||||
|
|
||||||
|
def encode(self, string):
|
||||||
|
if type(string) is str:
|
||||||
|
string = string.encode()
|
||||||
|
return self.model.tokenize(string)
|
||||||
|
|
||||||
|
def generate(self, context="", token_count=20, temperature=1, top_p=1, top_k=50, repetition_penalty=1, callback=None):
|
||||||
|
if type(context) is str:
|
||||||
|
context = context.encode()
|
||||||
|
tokens = self.model.tokenize(context)
|
||||||
|
|
||||||
|
output = b""
|
||||||
|
count = 0
|
||||||
|
for token in self.model.generate(tokens, top_k=top_k, top_p=top_p, temp=temperature, repeat_penalty=repetition_penalty):
|
||||||
|
text = self.model.detokenize([token])
|
||||||
|
output += text
|
||||||
|
if callback:
|
||||||
|
callback(text.decode())
|
||||||
|
|
||||||
|
count += 1
|
||||||
|
if count >= token_count or (token == self.model.token_eos()):
|
||||||
|
break
|
||||||
|
|
||||||
|
return output.decode()
|
||||||
|
|
||||||
|
def generate_with_streaming(self, **kwargs):
|
||||||
|
with Iteratorize(self.generate, kwargs, callback=None) as generator:
|
||||||
|
reply = ''
|
||||||
|
for token in generator:
|
||||||
|
reply += token
|
||||||
|
yield reply
|
||||||
@@ -1,3 +1,4 @@
|
|||||||
|
import gc
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
@@ -10,17 +11,17 @@ import torch
|
|||||||
import transformers
|
import transformers
|
||||||
from accelerate import infer_auto_device_map, init_empty_weights
|
from accelerate import infer_auto_device_map, init_empty_weights
|
||||||
from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer,
|
from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer,
|
||||||
BitsAndBytesConfig)
|
BitsAndBytesConfig, LlamaTokenizer)
|
||||||
|
|
||||||
import modules.shared as shared
|
import modules.shared as shared
|
||||||
|
from modules import llama_attn_hijack
|
||||||
|
|
||||||
transformers.logging.set_verbosity_error()
|
transformers.logging.set_verbosity_error()
|
||||||
|
|
||||||
local_rank = None
|
|
||||||
|
|
||||||
if shared.args.flexgen:
|
if shared.args.flexgen:
|
||||||
from flexgen.flex_opt import CompressionConfig, ExecutionEnv, OptLM, Policy
|
from flexgen.flex_opt import CompressionConfig, ExecutionEnv, OptLM, Policy
|
||||||
|
|
||||||
|
local_rank = None
|
||||||
if shared.args.deepspeed:
|
if shared.args.deepspeed:
|
||||||
import deepspeed
|
import deepspeed
|
||||||
from transformers.deepspeed import (HfDeepSpeedConfig,
|
from transformers.deepspeed import (HfDeepSpeedConfig,
|
||||||
@@ -103,7 +104,7 @@ def load_model(model_name):
|
|||||||
|
|
||||||
# llamacpp model
|
# llamacpp model
|
||||||
elif shared.is_llamacpp:
|
elif shared.is_llamacpp:
|
||||||
from modules.llamacpp_model import LlamaCppModel
|
from modules.llamacpp_model_alternative import LlamaCppModel
|
||||||
|
|
||||||
model_file = list(Path(f'{shared.args.model_dir}/{model_name}').glob('ggml*.bin'))[0]
|
model_file = list(Path(f'{shared.args.model_dir}/{model_name}').glob('ggml*.bin'))[0]
|
||||||
print(f"llama.cpp weights detected: {model_file}\n")
|
print(f"llama.cpp weights detected: {model_file}\n")
|
||||||
@@ -169,16 +170,46 @@ def load_model(model_name):
|
|||||||
|
|
||||||
model = AutoModelForCausalLM.from_pretrained(checkpoint, **params)
|
model = AutoModelForCausalLM.from_pretrained(checkpoint, **params)
|
||||||
|
|
||||||
|
# Hijack attention with xformers
|
||||||
|
if any((shared.args.xformers, shared.args.sdp_attention)):
|
||||||
|
llama_attn_hijack.hijack_llama_attention()
|
||||||
|
|
||||||
# Loading the tokenizer
|
# Loading the tokenizer
|
||||||
if any((k in shared.model_name.lower() for k in ['gpt4chan', 'gpt-4chan'])) and Path(f"{shared.args.model_dir}/gpt-j-6B/").exists():
|
if any((k in shared.model_name.lower() for k in ['gpt4chan', 'gpt-4chan'])) and Path(f"{shared.args.model_dir}/gpt-j-6B/").exists():
|
||||||
tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/gpt-j-6B/"))
|
tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/gpt-j-6B/"))
|
||||||
|
elif type(model) is transformers.LlamaForCausalLM:
|
||||||
|
tokenizer = LlamaTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/{shared.model_name}/"), clean_up_tokenization_spaces=True)
|
||||||
|
# Leaving this here until the LLaMA tokenizer gets figured out.
|
||||||
|
# For some people this fixes things, for others it causes an error.
|
||||||
|
try:
|
||||||
|
tokenizer.eos_token_id = 2
|
||||||
|
tokenizer.bos_token_id = 1
|
||||||
|
tokenizer.pad_token_id = 0
|
||||||
|
except:
|
||||||
|
pass
|
||||||
else:
|
else:
|
||||||
tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/{shared.model_name}/"))
|
tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/{shared.model_name}/"))
|
||||||
tokenizer.truncation_side = 'left'
|
|
||||||
|
|
||||||
print(f"Loaded the model in {(time.time()-t0):.2f} seconds.")
|
print(f"Loaded the model in {(time.time()-t0):.2f} seconds.")
|
||||||
return model, tokenizer
|
return model, tokenizer
|
||||||
|
|
||||||
|
|
||||||
|
def clear_torch_cache():
|
||||||
|
gc.collect()
|
||||||
|
if not shared.args.cpu:
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
|
||||||
|
def unload_model():
|
||||||
|
shared.model = shared.tokenizer = None
|
||||||
|
clear_torch_cache()
|
||||||
|
|
||||||
|
|
||||||
|
def reload_model():
|
||||||
|
unload_model()
|
||||||
|
shared.model, shared.tokenizer = load_model(shared.model_name)
|
||||||
|
|
||||||
|
|
||||||
def load_soft_prompt(name):
|
def load_soft_prompt(name):
|
||||||
if name == 'None':
|
if name == 'None':
|
||||||
shared.soft_prompt = False
|
shared.soft_prompt = False
|
||||||
|
|||||||
@@ -34,7 +34,13 @@ settings = {
|
|||||||
'context': 'This is a conversation with your Assistant. The Assistant is very helpful and is eager to chat with you and answer your questions.',
|
'context': 'This is a conversation with your Assistant. The Assistant is very helpful and is eager to chat with you and answer your questions.',
|
||||||
'greeting': 'Hello there!',
|
'greeting': 'Hello there!',
|
||||||
'end_of_turn': '',
|
'end_of_turn': '',
|
||||||
|
'custom_stopping_strings': '',
|
||||||
'stop_at_newline': False,
|
'stop_at_newline': False,
|
||||||
|
'add_bos_token': True,
|
||||||
|
'ban_eos_token': False,
|
||||||
|
'truncation_length': 2048,
|
||||||
|
'truncation_length_min': 0,
|
||||||
|
'truncation_length_max': 4096,
|
||||||
'chat_prompt_size': 2048,
|
'chat_prompt_size': 2048,
|
||||||
'chat_prompt_size_min': 0,
|
'chat_prompt_size_min': 0,
|
||||||
'chat_prompt_size_max': 2048,
|
'chat_prompt_size_max': 2048,
|
||||||
@@ -44,7 +50,7 @@ settings = {
|
|||||||
'default_extensions': [],
|
'default_extensions': [],
|
||||||
'chat_default_extensions': ["gallery"],
|
'chat_default_extensions': ["gallery"],
|
||||||
'presets': {
|
'presets': {
|
||||||
'default': 'NovelAI-Sphinx Moth',
|
'default': 'Default',
|
||||||
'.*(alpaca|llama)': "LLaMA-Precise",
|
'.*(alpaca|llama)': "LLaMA-Precise",
|
||||||
'.*pygmalion': 'NovelAI-Storywriter',
|
'.*pygmalion': 'NovelAI-Storywriter',
|
||||||
'.*RWKV': 'Naive',
|
'.*RWKV': 'Naive',
|
||||||
@@ -61,6 +67,7 @@ settings = {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def str2bool(v):
|
def str2bool(v):
|
||||||
if isinstance(v, bool):
|
if isinstance(v, bool):
|
||||||
return v
|
return v
|
||||||
@@ -71,6 +78,7 @@ def str2bool(v):
|
|||||||
else:
|
else:
|
||||||
raise argparse.ArgumentTypeError('Boolean value expected.')
|
raise argparse.ArgumentTypeError('Boolean value expected.')
|
||||||
|
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(formatter_class=lambda prog: argparse.HelpFormatter(prog, max_help_position=54))
|
parser = argparse.ArgumentParser(formatter_class=lambda prog: argparse.HelpFormatter(prog, max_help_position=54))
|
||||||
|
|
||||||
# Basic settings
|
# Basic settings
|
||||||
@@ -87,7 +95,7 @@ parser.add_argument('--extensions', type=str, nargs="+", help='The list of exten
|
|||||||
parser.add_argument('--verbose', action='store_true', help='Print the prompts to the terminal.')
|
parser.add_argument('--verbose', action='store_true', help='Print the prompts to the terminal.')
|
||||||
|
|
||||||
# Accelerate/transformers
|
# Accelerate/transformers
|
||||||
parser.add_argument('--cpu', action='store_true', help='Use the CPU to generate text.')
|
parser.add_argument('--cpu', action='store_true', help='Use the CPU to generate text. Warning: Training on CPU is extremely slow.')
|
||||||
parser.add_argument('--auto-devices', action='store_true', help='Automatically split the model across the available GPU(s) and CPU.')
|
parser.add_argument('--auto-devices', action='store_true', help='Automatically split the model across the available GPU(s) and CPU.')
|
||||||
parser.add_argument('--gpu-memory', type=str, nargs="+", help='Maxmimum GPU memory in GiB to be allocated per GPU. Example: --gpu-memory 10 for a single GPU, --gpu-memory 10 5 for two GPUs. You can also set values in MiB like --gpu-memory 3500MiB.')
|
parser.add_argument('--gpu-memory', type=str, nargs="+", help='Maxmimum GPU memory in GiB to be allocated per GPU. Example: --gpu-memory 10 for a single GPU, --gpu-memory 10 5 for two GPUs. You can also set values in MiB like --gpu-memory 3500MiB.')
|
||||||
parser.add_argument('--cpu-memory', type=str, help='Maximum CPU memory in GiB to allocate for offloaded weights. Same as above.')
|
parser.add_argument('--cpu-memory', type=str, help='Maximum CPU memory in GiB to allocate for offloaded weights. Same as above.')
|
||||||
@@ -96,6 +104,8 @@ parser.add_argument('--disk-cache-dir', type=str, default="cache", help='Directo
|
|||||||
parser.add_argument('--load-in-8bit', action='store_true', help='Load the model with 8-bit precision.')
|
parser.add_argument('--load-in-8bit', action='store_true', help='Load the model with 8-bit precision.')
|
||||||
parser.add_argument('--bf16', action='store_true', help='Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU.')
|
parser.add_argument('--bf16', action='store_true', help='Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU.')
|
||||||
parser.add_argument('--no-cache', action='store_true', help='Set use_cache to False while generating text. This reduces the VRAM usage a bit at a performance cost.')
|
parser.add_argument('--no-cache', action='store_true', help='Set use_cache to False while generating text. This reduces the VRAM usage a bit at a performance cost.')
|
||||||
|
parser.add_argument('--xformers', action='store_true', help="Use xformer's memory efficient attention. This should increase your tokens/s.")
|
||||||
|
parser.add_argument('--sdp-attention', action='store_true', help="Use torch 2.0's sdp attention.")
|
||||||
|
|
||||||
# llama.cpp
|
# llama.cpp
|
||||||
parser.add_argument('--threads', type=int, default=0, help='Number of threads to use in llama.cpp.')
|
parser.add_argument('--threads', type=int, default=0, help='Number of threads to use in llama.cpp.')
|
||||||
@@ -145,5 +155,6 @@ if args.cai_chat:
|
|||||||
print("Warning: --cai-chat is deprecated. Use --chat instead.")
|
print("Warning: --cai-chat is deprecated. Use --chat instead.")
|
||||||
args.chat = True
|
args.chat = True
|
||||||
|
|
||||||
|
|
||||||
def is_chat():
|
def is_chat():
|
||||||
return args.chat
|
return args.chat
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
import gc
|
import random
|
||||||
import re
|
import re
|
||||||
import time
|
import time
|
||||||
import traceback
|
import traceback
|
||||||
@@ -12,23 +12,38 @@ from modules.callbacks import (Iteratorize, Stream,
|
|||||||
_SentinelTokenStoppingCriteria)
|
_SentinelTokenStoppingCriteria)
|
||||||
from modules.extensions import apply_extensions
|
from modules.extensions import apply_extensions
|
||||||
from modules.html_generator import generate_4chan_html, generate_basic_html
|
from modules.html_generator import generate_4chan_html, generate_basic_html
|
||||||
from modules.models import local_rank
|
from modules.models import clear_torch_cache, local_rank
|
||||||
|
|
||||||
|
|
||||||
def get_max_prompt_length(tokens):
|
def get_max_prompt_length(state):
|
||||||
max_length = 2048-tokens
|
max_length = state['truncation_length'] - state['max_new_tokens']
|
||||||
if shared.soft_prompt:
|
if shared.soft_prompt:
|
||||||
max_length -= shared.soft_prompt_tensor.shape[1]
|
max_length -= shared.soft_prompt_tensor.shape[1]
|
||||||
return max_length
|
return max_length
|
||||||
|
|
||||||
def encode(prompt, tokens_to_generate=0, add_special_tokens=True):
|
|
||||||
|
def encode(prompt, add_special_tokens=True, add_bos_token=True, truncation_length=None):
|
||||||
if any((shared.is_RWKV, shared.is_llamacpp)):
|
if any((shared.is_RWKV, shared.is_llamacpp)):
|
||||||
input_ids = shared.tokenizer.encode(str(prompt))
|
input_ids = shared.tokenizer.encode(str(prompt))
|
||||||
input_ids = np.array(input_ids).reshape(1, len(input_ids))
|
input_ids = np.array(input_ids).reshape(1, len(input_ids))
|
||||||
return input_ids
|
return input_ids
|
||||||
else:
|
else:
|
||||||
input_ids = shared.tokenizer.encode(str(prompt), return_tensors='pt', truncation=True, max_length=get_max_prompt_length(tokens_to_generate), add_special_tokens=add_special_tokens)
|
input_ids = shared.tokenizer.encode(str(prompt), return_tensors='pt', add_special_tokens=add_special_tokens)
|
||||||
if shared.args.cpu:
|
|
||||||
|
# This is a hack for making replies more creative.
|
||||||
|
if not add_bos_token and input_ids[0][0] == shared.tokenizer.bos_token_id:
|
||||||
|
input_ids = input_ids[:, 1:]
|
||||||
|
|
||||||
|
# Llama adds this extra token when the first character is '\n', and this
|
||||||
|
# compromises the stopping criteria, so we just remove it
|
||||||
|
if type(shared.tokenizer) is transformers.LlamaTokenizer and input_ids[0][0] == 29871:
|
||||||
|
input_ids = input_ids[:, 1:]
|
||||||
|
|
||||||
|
# Handling truncation
|
||||||
|
if truncation_length is not None:
|
||||||
|
input_ids = input_ids[:, -truncation_length:]
|
||||||
|
|
||||||
|
if any((shared.is_RWKV, shared.is_llamacpp, shared.args.cpu)):
|
||||||
return input_ids
|
return input_ids
|
||||||
elif shared.args.flexgen:
|
elif shared.args.flexgen:
|
||||||
return input_ids.numpy()
|
return input_ids.numpy()
|
||||||
@@ -40,6 +55,7 @@ def encode(prompt, tokens_to_generate=0, add_special_tokens=True):
|
|||||||
else:
|
else:
|
||||||
return input_ids.cuda()
|
return input_ids.cuda()
|
||||||
|
|
||||||
|
|
||||||
def decode(output_ids):
|
def decode(output_ids):
|
||||||
# Open Assistant relies on special tokens like <|endoftext|>
|
# Open Assistant relies on special tokens like <|endoftext|>
|
||||||
if re.match('.*(oasst|galactica)-*', shared.model_name.lower()):
|
if re.match('.*(oasst|galactica)-*', shared.model_name.lower()):
|
||||||
@@ -49,6 +65,7 @@ def decode(output_ids):
|
|||||||
reply = reply.replace(r'<|endoftext|>', '')
|
reply = reply.replace(r'<|endoftext|>', '')
|
||||||
return reply
|
return reply
|
||||||
|
|
||||||
|
|
||||||
def generate_softprompt_input_tensors(input_ids):
|
def generate_softprompt_input_tensors(input_ids):
|
||||||
inputs_embeds = shared.model.transformer.wte(input_ids)
|
inputs_embeds = shared.model.transformer.wte(input_ids)
|
||||||
inputs_embeds = torch.cat((shared.soft_prompt_tensor, inputs_embeds), dim=1)
|
inputs_embeds = torch.cat((shared.soft_prompt_tensor, inputs_embeds), dim=1)
|
||||||
@@ -56,6 +73,7 @@ def generate_softprompt_input_tensors(input_ids):
|
|||||||
# filler_input_ids += shared.model.config.bos_token_id # setting dummy input_ids to bos tokens
|
# filler_input_ids += shared.model.config.bos_token_id # setting dummy input_ids to bos tokens
|
||||||
return inputs_embeds, filler_input_ids
|
return inputs_embeds, filler_input_ids
|
||||||
|
|
||||||
|
|
||||||
# Removes empty replies from gpt4chan outputs
|
# Removes empty replies from gpt4chan outputs
|
||||||
def fix_gpt4chan(s):
|
def fix_gpt4chan(s):
|
||||||
for i in range(10):
|
for i in range(10):
|
||||||
@@ -64,6 +82,7 @@ def fix_gpt4chan(s):
|
|||||||
s = re.sub("--- [0-9]*\n\n\n---", "---", s)
|
s = re.sub("--- [0-9]*\n\n\n---", "---", s)
|
||||||
return s
|
return s
|
||||||
|
|
||||||
|
|
||||||
# Fix the LaTeX equations in galactica
|
# Fix the LaTeX equations in galactica
|
||||||
def fix_galactica(s):
|
def fix_galactica(s):
|
||||||
s = s.replace(r'\[', r'$')
|
s = s.replace(r'\[', r'$')
|
||||||
@@ -75,6 +94,7 @@ def fix_galactica(s):
|
|||||||
s = re.sub(r"\n{3,}", "\n\n", s)
|
s = re.sub(r"\n{3,}", "\n\n", s)
|
||||||
return s
|
return s
|
||||||
|
|
||||||
|
|
||||||
def formatted_outputs(reply, model_name):
|
def formatted_outputs(reply, model_name):
|
||||||
if not shared.is_chat():
|
if not shared.is_chat():
|
||||||
if 'galactica' in model_name.lower():
|
if 'galactica' in model_name.lower():
|
||||||
@@ -88,45 +108,48 @@ def formatted_outputs(reply, model_name):
|
|||||||
else:
|
else:
|
||||||
return reply
|
return reply
|
||||||
|
|
||||||
def clear_torch_cache():
|
|
||||||
gc.collect()
|
|
||||||
if not shared.args.cpu:
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
|
|
||||||
def set_manual_seed(seed):
|
def set_manual_seed(seed):
|
||||||
if seed != -1:
|
seed = int(seed)
|
||||||
|
if seed == -1:
|
||||||
|
seed = random.randint(1, 2**31)
|
||||||
torch.manual_seed(seed)
|
torch.manual_seed(seed)
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
torch.cuda.manual_seed_all(seed)
|
torch.cuda.manual_seed_all(seed)
|
||||||
|
return seed
|
||||||
|
|
||||||
|
|
||||||
def stop_everything_event():
|
def stop_everything_event():
|
||||||
shared.stop_everything = True
|
shared.stop_everything = True
|
||||||
|
|
||||||
def generate_reply(question, generate_state, eos_token=None, stopping_strings=[]):
|
|
||||||
|
def generate_reply(question, state, eos_token=None, stopping_strings=[]):
|
||||||
clear_torch_cache()
|
clear_torch_cache()
|
||||||
set_manual_seed(generate_state['seed'])
|
seed = set_manual_seed(state['seed'])
|
||||||
shared.stop_everything = False
|
shared.stop_everything = False
|
||||||
generate_params = {}
|
generate_params = {}
|
||||||
t0 = time.time()
|
t0 = time.time()
|
||||||
|
|
||||||
original_question = question
|
original_question = question
|
||||||
if not shared.is_chat():
|
if not shared.is_chat():
|
||||||
question = apply_extensions(question, "input")
|
question = apply_extensions(question, 'input')
|
||||||
if shared.args.verbose:
|
|
||||||
print(f"\n\n{question}\n--------------------\n")
|
|
||||||
|
|
||||||
# These models are not part of Hugging Face, so we handle them
|
# These models are not part of Hugging Face, so we handle them
|
||||||
# separately and terminate the function call earlier
|
# separately and terminate the function call earlier
|
||||||
if any((shared.is_RWKV, shared.is_llamacpp)):
|
if any((shared.is_RWKV, shared.is_llamacpp)):
|
||||||
|
|
||||||
|
if shared.args.verbose:
|
||||||
|
print(f'\n\n{question}\n--------------------\n')
|
||||||
|
|
||||||
for k in ['temperature', 'top_p', 'top_k', 'repetition_penalty']:
|
for k in ['temperature', 'top_p', 'top_k', 'repetition_penalty']:
|
||||||
generate_params[k] = generate_state[k]
|
generate_params[k] = state[k]
|
||||||
generate_params["token_count"] = generate_state["max_new_tokens"]
|
generate_params['token_count'] = state['max_new_tokens']
|
||||||
try:
|
try:
|
||||||
if shared.args.no_stream:
|
if shared.args.no_stream:
|
||||||
reply = shared.model.generate(context=question, **generate_params)
|
reply = shared.model.generate(context=question, **generate_params)
|
||||||
output = original_question + reply
|
output = original_question + reply
|
||||||
if not shared.is_chat():
|
if not shared.is_chat():
|
||||||
reply = original_question + apply_extensions(reply, "output")
|
reply = original_question + apply_extensions(reply, 'output')
|
||||||
yield formatted_outputs(reply, shared.model_name)
|
yield formatted_outputs(reply, shared.model_name)
|
||||||
else:
|
else:
|
||||||
if not shared.is_chat():
|
if not shared.is_chat():
|
||||||
@@ -137,7 +160,7 @@ def generate_reply(question, generate_state, eos_token=None, stopping_strings=[]
|
|||||||
for reply in shared.model.generate_with_streaming(context=question, **generate_params):
|
for reply in shared.model.generate_with_streaming(context=question, **generate_params):
|
||||||
output = original_question + reply
|
output = original_question + reply
|
||||||
if not shared.is_chat():
|
if not shared.is_chat():
|
||||||
reply = original_question + apply_extensions(reply, "output")
|
reply = original_question + apply_extensions(reply, 'output')
|
||||||
yield formatted_outputs(reply, shared.model_name)
|
yield formatted_outputs(reply, shared.model_name)
|
||||||
|
|
||||||
except Exception:
|
except Exception:
|
||||||
@@ -146,47 +169,53 @@ def generate_reply(question, generate_state, eos_token=None, stopping_strings=[]
|
|||||||
t1 = time.time()
|
t1 = time.time()
|
||||||
original_tokens = len(encode(original_question)[0])
|
original_tokens = len(encode(original_question)[0])
|
||||||
new_tokens = len(encode(output)[0]) - original_tokens
|
new_tokens = len(encode(output)[0]) - original_tokens
|
||||||
print(f"Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens})")
|
print(f'Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens}, seed {seed})')
|
||||||
return
|
return
|
||||||
|
|
||||||
input_ids = encode(question, generate_state['max_new_tokens'])
|
input_ids = encode(question, add_bos_token=state['add_bos_token'], truncation_length=get_max_prompt_length(state))
|
||||||
original_input_ids = input_ids
|
original_input_ids = input_ids
|
||||||
output = input_ids[0]
|
output = input_ids[0]
|
||||||
|
|
||||||
|
if shared.args.verbose:
|
||||||
|
print(f'\n\n{decode(input_ids[0])}\n--------------------\n')
|
||||||
|
|
||||||
cuda = not any((shared.args.cpu, shared.args.deepspeed, shared.args.flexgen))
|
cuda = not any((shared.args.cpu, shared.args.deepspeed, shared.args.flexgen))
|
||||||
eos_token_ids = [shared.tokenizer.eos_token_id] if shared.tokenizer.eos_token_id is not None else []
|
eos_token_ids = [shared.tokenizer.eos_token_id] if shared.tokenizer.eos_token_id is not None else []
|
||||||
if eos_token is not None:
|
if eos_token is not None:
|
||||||
eos_token_ids.append(int(encode(eos_token)[0][-1]))
|
eos_token_ids.append(int(encode(eos_token)[0][-1]))
|
||||||
stopping_criteria_list = transformers.StoppingCriteriaList()
|
|
||||||
if type(stopping_strings) is list and len(stopping_strings) > 0:
|
|
||||||
t = [encode(string, 0, add_special_tokens=False) for string in stopping_strings]
|
|
||||||
stopping_criteria_list.append(_SentinelTokenStoppingCriteria(sentinel_token_ids=t, starting_idx=len(input_ids[0])))
|
|
||||||
|
|
||||||
generate_params["max_new_tokens"] = generate_state['max_new_tokens']
|
# Handling the stopping strings
|
||||||
|
stopping_criteria_list = transformers.StoppingCriteriaList()
|
||||||
|
for st in [stopping_strings, state['custom_stopping_strings']]:
|
||||||
|
if type(st) is list and len(st) > 0:
|
||||||
|
sentinel_token_ids = [encode(string, add_special_tokens=False) for string in st]
|
||||||
|
stopping_criteria_list.append(_SentinelTokenStoppingCriteria(sentinel_token_ids=sentinel_token_ids, starting_idx=len(input_ids[0])))
|
||||||
|
break
|
||||||
|
|
||||||
if not shared.args.flexgen:
|
if not shared.args.flexgen:
|
||||||
for k in ["do_sample", "temperature", "top_p", "typical_p", "repetition_penalty", "encoder_repetition_penalty", "top_k", "min_length", "no_repeat_ngram_size", "num_beams", "penalty_alpha", "length_penalty", "early_stopping"]:
|
for k in ['max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']:
|
||||||
generate_params[k] = generate_state[k]
|
generate_params[k] = state[k]
|
||||||
generate_params["eos_token_id"] = eos_token_ids
|
generate_params['eos_token_id'] = eos_token_ids
|
||||||
generate_params["stopping_criteria"] = stopping_criteria_list
|
generate_params['stopping_criteria'] = stopping_criteria_list
|
||||||
if shared.args.no_stream:
|
if state['ban_eos_token']:
|
||||||
generate_params["min_length"] = 0
|
generate_params['suppress_tokens'] = [shared.tokenizer.eos_token_id]
|
||||||
else:
|
else:
|
||||||
for k in ["do_sample", "temperature"]:
|
for k in ['max_new_tokens', 'do_sample', 'temperature']:
|
||||||
generate_params[k] = generate_state[k]
|
generate_params[k] = state[k]
|
||||||
generate_params["stop"] = generate_state["eos_token_ids"][-1]
|
generate_params['stop'] = state['eos_token_ids'][-1]
|
||||||
if not shared.args.no_stream:
|
if not shared.args.no_stream:
|
||||||
generate_params["max_new_tokens"] = 8
|
generate_params['max_new_tokens'] = 8
|
||||||
|
|
||||||
if shared.args.no_cache:
|
if shared.args.no_cache:
|
||||||
generate_params.update({"use_cache": False})
|
generate_params.update({'use_cache': False})
|
||||||
if shared.args.deepspeed:
|
if shared.args.deepspeed:
|
||||||
generate_params.update({"synced_gpus": True})
|
generate_params.update({'synced_gpus': True})
|
||||||
if shared.soft_prompt:
|
if shared.soft_prompt:
|
||||||
inputs_embeds, filler_input_ids = generate_softprompt_input_tensors(input_ids)
|
inputs_embeds, filler_input_ids = generate_softprompt_input_tensors(input_ids)
|
||||||
generate_params.update({"inputs_embeds": inputs_embeds})
|
generate_params.update({'inputs_embeds': inputs_embeds})
|
||||||
generate_params.update({"inputs": filler_input_ids})
|
generate_params.update({'inputs': filler_input_ids})
|
||||||
else:
|
else:
|
||||||
generate_params.update({"inputs": input_ids})
|
generate_params.update({'inputs': input_ids})
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Generate the entire reply at once.
|
# Generate the entire reply at once.
|
||||||
@@ -201,7 +230,7 @@ def generate_reply(question, generate_state, eos_token=None, stopping_strings=[]
|
|||||||
new_tokens = len(output) - len(input_ids[0])
|
new_tokens = len(output) - len(input_ids[0])
|
||||||
reply = decode(output[-new_tokens:])
|
reply = decode(output[-new_tokens:])
|
||||||
if not shared.is_chat():
|
if not shared.is_chat():
|
||||||
reply = original_question + apply_extensions(reply, "output")
|
reply = original_question + apply_extensions(reply, 'output')
|
||||||
|
|
||||||
yield formatted_outputs(reply, shared.model_name)
|
yield formatted_outputs(reply, shared.model_name)
|
||||||
|
|
||||||
@@ -228,7 +257,7 @@ def generate_reply(question, generate_state, eos_token=None, stopping_strings=[]
|
|||||||
new_tokens = len(output) - len(input_ids[0])
|
new_tokens = len(output) - len(input_ids[0])
|
||||||
reply = decode(output[-new_tokens:])
|
reply = decode(output[-new_tokens:])
|
||||||
if not shared.is_chat():
|
if not shared.is_chat():
|
||||||
reply = original_question + apply_extensions(reply, "output")
|
reply = original_question + apply_extensions(reply, 'output')
|
||||||
|
|
||||||
if output[-1] in eos_token_ids:
|
if output[-1] in eos_token_ids:
|
||||||
break
|
break
|
||||||
@@ -236,7 +265,7 @@ def generate_reply(question, generate_state, eos_token=None, stopping_strings=[]
|
|||||||
|
|
||||||
# Stream the output naively for FlexGen since it doesn't support 'stopping_criteria'
|
# Stream the output naively for FlexGen since it doesn't support 'stopping_criteria'
|
||||||
else:
|
else:
|
||||||
for i in range(generate_state['max_new_tokens']//8+1):
|
for i in range(state['max_new_tokens'] // 8 + 1):
|
||||||
clear_torch_cache()
|
clear_torch_cache()
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
output = shared.model.generate(**generate_params)[0]
|
output = shared.model.generate(**generate_params)[0]
|
||||||
@@ -246,7 +275,7 @@ def generate_reply(question, generate_state, eos_token=None, stopping_strings=[]
|
|||||||
new_tokens = len(output) - len(original_input_ids[0])
|
new_tokens = len(output) - len(original_input_ids[0])
|
||||||
reply = decode(output[-new_tokens:])
|
reply = decode(output[-new_tokens:])
|
||||||
if not shared.is_chat():
|
if not shared.is_chat():
|
||||||
reply = original_question + apply_extensions(reply, "output")
|
reply = original_question + apply_extensions(reply, 'output')
|
||||||
|
|
||||||
if np.count_nonzero(np.isin(input_ids[0], eos_token_ids)) < np.count_nonzero(np.isin(output, eos_token_ids)):
|
if np.count_nonzero(np.isin(input_ids[0], eos_token_ids)) < np.count_nonzero(np.isin(output, eos_token_ids)):
|
||||||
break
|
break
|
||||||
@@ -255,10 +284,10 @@ def generate_reply(question, generate_state, eos_token=None, stopping_strings=[]
|
|||||||
input_ids = np.reshape(output, (1, output.shape[0]))
|
input_ids = np.reshape(output, (1, output.shape[0]))
|
||||||
if shared.soft_prompt:
|
if shared.soft_prompt:
|
||||||
inputs_embeds, filler_input_ids = generate_softprompt_input_tensors(input_ids)
|
inputs_embeds, filler_input_ids = generate_softprompt_input_tensors(input_ids)
|
||||||
generate_params.update({"inputs_embeds": inputs_embeds})
|
generate_params.update({'inputs_embeds': inputs_embeds})
|
||||||
generate_params.update({"inputs": filler_input_ids})
|
generate_params.update({'inputs': filler_input_ids})
|
||||||
else:
|
else:
|
||||||
generate_params.update({"inputs": input_ids})
|
generate_params.update({'inputs': input_ids})
|
||||||
|
|
||||||
yield formatted_outputs(reply, shared.model_name)
|
yield formatted_outputs(reply, shared.model_name)
|
||||||
|
|
||||||
@@ -268,5 +297,5 @@ def generate_reply(question, generate_state, eos_token=None, stopping_strings=[]
|
|||||||
t1 = time.time()
|
t1 = time.time()
|
||||||
original_tokens = len(original_input_ids[0])
|
original_tokens = len(original_input_ids[0])
|
||||||
new_tokens = len(output) - original_tokens
|
new_tokens = len(output) - original_tokens
|
||||||
print(f"Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens})")
|
print(f'Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens}, seed {seed})')
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -19,8 +19,10 @@ CURRENT_STEPS = 0
|
|||||||
MAX_STEPS = 0
|
MAX_STEPS = 0
|
||||||
CURRENT_GRADIENT_ACCUM = 1
|
CURRENT_GRADIENT_ACCUM = 1
|
||||||
|
|
||||||
|
|
||||||
def get_dataset(path: str, ext: str):
|
def get_dataset(path: str, ext: str):
|
||||||
return ['None'] + sorted(set((k.stem for k in Path(path).glob(f'*.{ext}'))), key=str.lower)
|
return ['None'] + sorted(set([k.stem for k in Path(path).glob(f'*.{ext}') if k.stem != 'put-trainer-datasets-here']), key=str.lower)
|
||||||
|
|
||||||
|
|
||||||
def create_train_interface():
|
def create_train_interface():
|
||||||
with gr.Tab('Train LoRA', elem_id='lora-train-tab'):
|
with gr.Tab('Train LoRA', elem_id='lora-train-tab'):
|
||||||
@@ -45,28 +47,34 @@ def create_train_interface():
|
|||||||
with gr.Row():
|
with gr.Row():
|
||||||
dataset = gr.Dropdown(choices=get_dataset('training/datasets', 'json'), value='None', label='Dataset', info='The dataset file to use for training.')
|
dataset = gr.Dropdown(choices=get_dataset('training/datasets', 'json'), value='None', label='Dataset', info='The dataset file to use for training.')
|
||||||
ui.create_refresh_button(dataset, lambda: None, lambda: {'choices': get_dataset('training/datasets', 'json')}, 'refresh-button')
|
ui.create_refresh_button(dataset, lambda: None, lambda: {'choices': get_dataset('training/datasets', 'json')}, 'refresh-button')
|
||||||
eval_dataset = gr.Dropdown(choices=get_dataset('training/datasets', 'json'), value='None', label='Evaluation Dataset', info='The dataset file used to evaluate the model after training.')
|
eval_dataset = gr.Dropdown(choices=get_dataset('training/datasets', 'json'), value='None', label='Evaluation Dataset', info='The (optional) dataset file used to evaluate the model after training.')
|
||||||
ui.create_refresh_button(eval_dataset, lambda: None, lambda: {'choices': get_dataset('training/datasets', 'json')}, 'refresh-button')
|
ui.create_refresh_button(eval_dataset, lambda: None, lambda: {'choices': get_dataset('training/datasets', 'json')}, 'refresh-button')
|
||||||
format = gr.Dropdown(choices=get_dataset('training/formats', 'json'), value='None', label='Data Format', info='The format file used to decide how to format the dataset input.')
|
format = gr.Dropdown(choices=get_dataset('training/formats', 'json'), value='None', label='Data Format', info='The format file used to decide how to format the dataset input.')
|
||||||
ui.create_refresh_button(format, lambda: None, lambda: {'choices': get_dataset('training/formats', 'json')}, 'refresh-button')
|
ui.create_refresh_button(format, lambda: None, lambda: {'choices': get_dataset('training/formats', 'json')}, 'refresh-button')
|
||||||
|
|
||||||
with gr.Tab(label="Raw Text File"):
|
with gr.Tab(label="Raw Text File"):
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
raw_text_file = gr.Dropdown(choices=get_dataset('training/datasets', 'txt'), value='None', label='Text File', info='The raw text file to use for training.')
|
raw_text_file = gr.Dropdown(choices=get_dataset('training/datasets', 'txt'), value='None', label='Text File', info='The raw text file to use for training.')
|
||||||
ui.create_refresh_button(raw_text_file, lambda: None, lambda: {'choices': get_dataset('training/datasets', 'txt')}, 'refresh-button')
|
ui.create_refresh_button(raw_text_file, lambda: None, lambda: {'choices': get_dataset('training/datasets', 'txt')}, 'refresh-button')
|
||||||
overlap_len = gr.Slider(label='Overlap Length', minimum=0,maximum=512, value=128, step=16, info='Overlap length - ie how many tokens from the prior chunk of text to include into the next chunk. (The chunks themselves will be of a size determined by Cutoff Length above). Setting overlap to exactly half the cutoff length may be ideal.')
|
with gr.Row():
|
||||||
|
overlap_len = gr.Slider(label='Overlap Length', minimum=0, maximum=512, value=128, step=16, info='Overlap length - ie how many tokens from the prior chunk of text to include into the next chunk. (The chunks themselves will be of a size determined by Cutoff Length below). Setting overlap to exactly half the cutoff length may be ideal.')
|
||||||
|
newline_favor_len = gr.Slider(label='Prefer Newline Cut Length', minimum=0, maximum=512, value=128, step=16, info='Length (in characters, not tokens) of the maximum distance to shift an overlap cut by to ensure chunks cut at newlines. If too low, cuts may occur in the middle of lines.')
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
start_button = gr.Button("Start LoRA Training")
|
start_button = gr.Button("Start LoRA Training")
|
||||||
stop_button = gr.Button("Interrupt")
|
stop_button = gr.Button("Interrupt")
|
||||||
|
|
||||||
output = gr.Markdown(value="Ready")
|
output = gr.Markdown(value="Ready")
|
||||||
start_button.click(do_train, [lora_name, micro_batch_size, batch_size, epochs, learning_rate, lora_rank, lora_alpha, lora_dropout, cutoff_len, dataset, eval_dataset, format, raw_text_file, overlap_len], [output])
|
start_button.click(do_train, [lora_name, micro_batch_size, batch_size, epochs, learning_rate, lora_rank, lora_alpha, lora_dropout,
|
||||||
|
cutoff_len, dataset, eval_dataset, format, raw_text_file, overlap_len, newline_favor_len], [output])
|
||||||
stop_button.click(do_interrupt, [], [], cancels=[], queue=False)
|
stop_button.click(do_interrupt, [], [], cancels=[], queue=False)
|
||||||
|
|
||||||
|
|
||||||
def do_interrupt():
|
def do_interrupt():
|
||||||
global WANT_INTERRUPT
|
global WANT_INTERRUPT
|
||||||
WANT_INTERRUPT = True
|
WANT_INTERRUPT = True
|
||||||
|
|
||||||
|
|
||||||
class Callbacks(transformers.TrainerCallback):
|
class Callbacks(transformers.TrainerCallback):
|
||||||
def on_step_begin(self, args: transformers.TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, **kwargs):
|
def on_step_begin(self, args: transformers.TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, **kwargs):
|
||||||
global CURRENT_STEPS, MAX_STEPS
|
global CURRENT_STEPS, MAX_STEPS
|
||||||
@@ -75,6 +83,7 @@ class Callbacks(transformers.TrainerCallback):
|
|||||||
if WANT_INTERRUPT:
|
if WANT_INTERRUPT:
|
||||||
control.should_epoch_stop = True
|
control.should_epoch_stop = True
|
||||||
control.should_training_stop = True
|
control.should_training_stop = True
|
||||||
|
|
||||||
def on_substep_end(self, args: transformers.TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, **kwargs):
|
def on_substep_end(self, args: transformers.TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, **kwargs):
|
||||||
global CURRENT_STEPS
|
global CURRENT_STEPS
|
||||||
CURRENT_STEPS += 1
|
CURRENT_STEPS += 1
|
||||||
@@ -82,6 +91,7 @@ class Callbacks(transformers.TrainerCallback):
|
|||||||
control.should_epoch_stop = True
|
control.should_epoch_stop = True
|
||||||
control.should_training_stop = True
|
control.should_training_stop = True
|
||||||
|
|
||||||
|
|
||||||
def clean_path(base_path: str, path: str):
|
def clean_path(base_path: str, path: str):
|
||||||
""""Strips unusual symbols and forcibly builds a path as relative to the intended directory."""
|
""""Strips unusual symbols and forcibly builds a path as relative to the intended directory."""
|
||||||
# TODO: Probably could do with a security audit to guarantee there's no ways this can be bypassed to target an unwanted path.
|
# TODO: Probably could do with a security audit to guarantee there's no ways this can be bypassed to target an unwanted path.
|
||||||
@@ -91,8 +101,9 @@ def clean_path(base_path: str, path: str):
|
|||||||
return path
|
return path
|
||||||
return f'{Path(base_path).absolute()}/{path}'
|
return f'{Path(base_path).absolute()}/{path}'
|
||||||
|
|
||||||
def do_train(lora_name: str, micro_batch_size: int, batch_size: int, epochs: int, learning_rate: str, lora_rank: int,
|
|
||||||
lora_alpha: int, lora_dropout: float, cutoff_len: int, dataset: str, eval_dataset: str, format: str, raw_text_file: str, overlap_len: int):
|
def do_train(lora_name: str, micro_batch_size: int, batch_size: int, epochs: int, learning_rate: str, lora_rank: int, lora_alpha: int, lora_dropout: float,
|
||||||
|
cutoff_len: int, dataset: str, eval_dataset: str, format: str, raw_text_file: str, overlap_len: int, newline_favor_len: int):
|
||||||
global WANT_INTERRUPT, CURRENT_STEPS, MAX_STEPS, CURRENT_GRADIENT_ACCUM
|
global WANT_INTERRUPT, CURRENT_STEPS, MAX_STEPS, CURRENT_GRADIENT_ACCUM
|
||||||
WANT_INTERRUPT = False
|
WANT_INTERRUPT = False
|
||||||
CURRENT_STEPS = 0
|
CURRENT_STEPS = 0
|
||||||
@@ -103,6 +114,25 @@ def do_train(lora_name: str, micro_batch_size: int, batch_size: int, epochs: int
|
|||||||
lora_name = f"{shared.args.lora_dir}/{clean_path(None, lora_name)}"
|
lora_name = f"{shared.args.lora_dir}/{clean_path(None, lora_name)}"
|
||||||
actual_lr = float(learning_rate)
|
actual_lr = float(learning_rate)
|
||||||
|
|
||||||
|
model_type = type(shared.model).__name__
|
||||||
|
if model_type != "LlamaForCausalLM":
|
||||||
|
if model_type == "PeftModelForCausalLM":
|
||||||
|
yield "You are trying to train a LoRA while you already have another LoRA loaded. This will work, but may have unexpected effects. *(Will continue anyway in 5 seconds, press `Interrupt` to stop.)*"
|
||||||
|
print("Warning: Training LoRA over top of another LoRA. May have unexpected effects.")
|
||||||
|
else:
|
||||||
|
yield "LoRA training has only currently been validated for LLaMA models. Unexpected errors may follow. *(Will continue anyway in 5 seconds, press `Interrupt` to stop.)*"
|
||||||
|
print(f"Warning: LoRA training has only currently been validated for LLaMA models. (Found model type: {model_type})")
|
||||||
|
time.sleep(5)
|
||||||
|
|
||||||
|
if shared.args.wbits > 0 or shared.args.gptq_bits > 0:
|
||||||
|
yield "LoRA training does not yet support 4bit. Please use `--load-in-8bit` for now."
|
||||||
|
return
|
||||||
|
|
||||||
|
elif not shared.args.load_in_8bit:
|
||||||
|
yield "It is highly recommended you use `--load-in-8bit` for LoRA training. *(Will continue anyway in 2 seconds, press `Interrupt` to stop.)*"
|
||||||
|
print("Warning: It is highly recommended you use `--load-in-8bit` for LoRA training.")
|
||||||
|
time.sleep(2) # Give it a moment for the message to show in UI before continuing
|
||||||
|
|
||||||
if cutoff_len <= 0 or micro_batch_size <= 0 or batch_size <= 0 or actual_lr <= 0 or lora_rank <= 0 or lora_alpha <= 0:
|
if cutoff_len <= 0 or micro_batch_size <= 0 or batch_size <= 0 or actual_lr <= 0 or lora_rank <= 0 or lora_alpha <= 0:
|
||||||
yield "Cannot input zeroes."
|
yield "Cannot input zeroes."
|
||||||
return
|
return
|
||||||
@@ -122,19 +152,24 @@ def do_train(lora_name: str, micro_batch_size: int, batch_size: int, epochs: int
|
|||||||
# == Prep the dataset, format, etc ==
|
# == Prep the dataset, format, etc ==
|
||||||
if raw_text_file not in ['None', '']:
|
if raw_text_file not in ['None', '']:
|
||||||
print("Loading raw text file dataset...")
|
print("Loading raw text file dataset...")
|
||||||
with open(clean_path('training/datasets', f'{raw_text_file}.txt'), 'r') as file:
|
with open(clean_path('training/datasets', f'{raw_text_file}.txt'), 'r', encoding='utf-8') as file:
|
||||||
raw_text = file.read()
|
raw_text = file.read()
|
||||||
tokens = shared.tokenizer.encode(raw_text)
|
tokens = shared.tokenizer.encode(raw_text)
|
||||||
del raw_text # Note: could be a gig for a large dataset, so delete redundant data as we go to be safe on RAM
|
del raw_text # Note: could be a gig for a large dataset, so delete redundant data as we go to be safe on RAM
|
||||||
|
|
||||||
tokens = list(split_chunks(tokens, cutoff_len - overlap_len))
|
tokens = list(split_chunks(tokens, cutoff_len - overlap_len))
|
||||||
for i in range(1, len(tokens)):
|
for i in range(1, len(tokens)):
|
||||||
tokens[i] = tokens[i - 1][-overlap_len:] + tokens[i]
|
tokens[i] = tokens[i - 1][-overlap_len:] + tokens[i]
|
||||||
text_chunks = [shared.tokenizer.decode(x) for x in tokens]
|
text_chunks = [shared.tokenizer.decode(x) for x in tokens]
|
||||||
del tokens
|
del tokens
|
||||||
data = Dataset.from_list([tokenize(x) for x in text_chunks])
|
|
||||||
train_data = data.shuffle()
|
if newline_favor_len > 0:
|
||||||
eval_data = None
|
text_chunks = [cut_chunk_for_newline(x, newline_favor_len) for x in text_chunks]
|
||||||
|
|
||||||
|
train_data = Dataset.from_list([tokenize(x) for x in text_chunks])
|
||||||
del text_chunks
|
del text_chunks
|
||||||
|
train_data = train_data.shuffle()
|
||||||
|
eval_data = None
|
||||||
|
|
||||||
else:
|
else:
|
||||||
if dataset in ['None', '']:
|
if dataset in ['None', '']:
|
||||||
@@ -203,7 +238,7 @@ def do_train(lora_name: str, micro_batch_size: int, batch_size: int, epochs: int
|
|||||||
warmup_steps=100,
|
warmup_steps=100,
|
||||||
num_train_epochs=epochs,
|
num_train_epochs=epochs,
|
||||||
learning_rate=actual_lr,
|
learning_rate=actual_lr,
|
||||||
fp16=True,
|
fp16=False if shared.args.cpu else True,
|
||||||
logging_steps=20,
|
logging_steps=20,
|
||||||
evaluation_strategy="steps" if eval_data is not None else "no",
|
evaluation_strategy="steps" if eval_data is not None else "no",
|
||||||
save_strategy="steps",
|
save_strategy="steps",
|
||||||
@@ -213,7 +248,8 @@ def do_train(lora_name: str, micro_batch_size: int, batch_size: int, epochs: int
|
|||||||
save_total_limit=3,
|
save_total_limit=3,
|
||||||
load_best_model_at_end=True if eval_data is not None else False,
|
load_best_model_at_end=True if eval_data is not None else False,
|
||||||
# TODO: Enable multi-device support
|
# TODO: Enable multi-device support
|
||||||
ddp_find_unused_parameters=None
|
ddp_find_unused_parameters=None,
|
||||||
|
no_cuda=shared.args.cpu
|
||||||
),
|
),
|
||||||
data_collator=transformers.DataCollatorForLanguageModeling(shared.tokenizer, mlm=False),
|
data_collator=transformers.DataCollatorForLanguageModeling(shared.tokenizer, mlm=False),
|
||||||
callbacks=list([Callbacks()])
|
callbacks=list([Callbacks()])
|
||||||
@@ -232,33 +268,37 @@ def do_train(lora_name: str, micro_batch_size: int, batch_size: int, epochs: int
|
|||||||
# TODO: save/load checkpoints to resume from?
|
# TODO: save/load checkpoints to resume from?
|
||||||
print("Starting training...")
|
print("Starting training...")
|
||||||
yield "Starting..."
|
yield "Starting..."
|
||||||
|
if WANT_INTERRUPT:
|
||||||
|
yield "Interrupted before start."
|
||||||
|
return
|
||||||
|
|
||||||
def threadedRun():
|
def threaded_run():
|
||||||
trainer.train()
|
trainer.train()
|
||||||
|
|
||||||
thread = threading.Thread(target=threadedRun)
|
thread = threading.Thread(target=threaded_run)
|
||||||
thread.start()
|
thread.start()
|
||||||
lastStep = 0
|
last_step = 0
|
||||||
startTime = time.perf_counter()
|
start_time = time.perf_counter()
|
||||||
|
|
||||||
while thread.is_alive():
|
while thread.is_alive():
|
||||||
time.sleep(0.5)
|
time.sleep(0.5)
|
||||||
if WANT_INTERRUPT:
|
if WANT_INTERRUPT:
|
||||||
yield "Interrupting, please wait... *(Run will stop after the current training step completes.)*"
|
yield "Interrupting, please wait... *(Run will stop after the current training step completes.)*"
|
||||||
elif CURRENT_STEPS != lastStep:
|
|
||||||
lastStep = CURRENT_STEPS
|
elif CURRENT_STEPS != last_step:
|
||||||
timeElapsed = time.perf_counter() - startTime
|
last_step = CURRENT_STEPS
|
||||||
if timeElapsed <= 0:
|
time_elapsed = time.perf_counter() - start_time
|
||||||
timerInfo = ""
|
if time_elapsed <= 0:
|
||||||
totalTimeEstimate = 999
|
timer_info = ""
|
||||||
|
total_time_estimate = 999
|
||||||
else:
|
else:
|
||||||
its = CURRENT_STEPS / timeElapsed
|
its = CURRENT_STEPS / time_elapsed
|
||||||
if its > 1:
|
if its > 1:
|
||||||
timerInfo = f"`{its:.2f}` it/s"
|
timer_info = f"`{its:.2f}` it/s"
|
||||||
else:
|
else:
|
||||||
timerInfo = f"`{1.0/its:.2f}` s/it"
|
timer_info = f"`{1.0/its:.2f}` s/it"
|
||||||
totalTimeEstimate = (1.0/its) * (MAX_STEPS)
|
total_time_estimate = (1.0 / its) * (MAX_STEPS)
|
||||||
yield f"Running... **{CURRENT_STEPS}** / **{MAX_STEPS}** ... {timerInfo}, `{timeElapsed:.0f}`/`{totalTimeEstimate:.0f}` seconds"
|
yield f"Running... **{CURRENT_STEPS}** / **{MAX_STEPS}** ... {timer_info}, {format_time(time_elapsed)} / {format_time(total_time_estimate)} ... {format_time(total_time_estimate - time_elapsed)} remaining"
|
||||||
|
|
||||||
print("Training complete, saving...")
|
print("Training complete, saving...")
|
||||||
lora_model.save_pretrained(lora_name)
|
lora_model.save_pretrained(lora_name)
|
||||||
@@ -270,6 +310,31 @@ def do_train(lora_name: str, micro_batch_size: int, batch_size: int, epochs: int
|
|||||||
print("Training complete!")
|
print("Training complete!")
|
||||||
yield f"Done! LoRA saved to `{lora_name}`"
|
yield f"Done! LoRA saved to `{lora_name}`"
|
||||||
|
|
||||||
|
|
||||||
def split_chunks(arr, step):
|
def split_chunks(arr, step):
|
||||||
for i in range(0, len(arr), step):
|
for i in range(0, len(arr), step):
|
||||||
yield arr[i:i + step]
|
yield arr[i:i + step]
|
||||||
|
|
||||||
|
|
||||||
|
def cut_chunk_for_newline(chunk: str, max_length: int):
|
||||||
|
if '\n' not in chunk:
|
||||||
|
return chunk
|
||||||
|
first_newline = chunk.index('\n')
|
||||||
|
if first_newline < max_length:
|
||||||
|
chunk = chunk[first_newline + 1:]
|
||||||
|
if '\n' not in chunk:
|
||||||
|
return chunk
|
||||||
|
last_newline = chunk.rindex('\n')
|
||||||
|
if len(chunk) - last_newline < max_length:
|
||||||
|
chunk = chunk[:last_newline]
|
||||||
|
return chunk
|
||||||
|
|
||||||
|
|
||||||
|
def format_time(seconds: float):
|
||||||
|
if seconds < 120:
|
||||||
|
return f"`{seconds:.0f}` seconds"
|
||||||
|
minutes = seconds / 60
|
||||||
|
if minutes < 120:
|
||||||
|
return f"`{minutes:.0f}` minutes"
|
||||||
|
hours = minutes / 60
|
||||||
|
return f"`{hours:.0f}` hours"
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ with open(Path(__file__).resolve().parent / '../css/main.js', 'r') as f:
|
|||||||
with open(Path(__file__).resolve().parent / '../css/chat.js', 'r') as f:
|
with open(Path(__file__).resolve().parent / '../css/chat.js', 'r') as f:
|
||||||
chat_js = f.read()
|
chat_js = f.read()
|
||||||
|
|
||||||
|
|
||||||
class ToolButton(gr.Button, gr.components.FormComponent):
|
class ToolButton(gr.Button, gr.components.FormComponent):
|
||||||
"""Small button with single emoji as text, fits inside gradio forms"""
|
"""Small button with single emoji as text, fits inside gradio forms"""
|
||||||
|
|
||||||
@@ -22,6 +23,7 @@ class ToolButton(gr.Button, gr.components.FormComponent):
|
|||||||
def get_block_name(self):
|
def get_block_name(self):
|
||||||
return "button"
|
return "button"
|
||||||
|
|
||||||
|
|
||||||
def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id):
|
def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id):
|
||||||
def refresh():
|
def refresh():
|
||||||
refresh_method()
|
refresh_method()
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
do_sample=True
|
do_sample=True
|
||||||
top_p=0.5
|
top_p=0.95
|
||||||
top_k=40
|
top_k=50
|
||||||
temperature=0.7
|
temperature=1
|
||||||
repetition_penalty=1.2
|
repetition_penalty=1.2
|
||||||
typical_p=1.0
|
typical_p=1.0
|
||||||
early_stopping=False
|
|
||||||
|
|||||||
@@ -1,16 +1,18 @@
|
|||||||
accelerate==0.18.0
|
accelerate==0.18.0
|
||||||
bitsandbytes==0.37.2
|
|
||||||
datasets
|
datasets
|
||||||
flexgen==0.1.7
|
flexgen==0.1.7
|
||||||
gradio==3.24.1
|
gradio==3.24.1
|
||||||
llamacpp==0.1.11
|
|
||||||
markdown
|
markdown
|
||||||
numpy
|
numpy
|
||||||
|
Pillow>=9.5.0
|
||||||
peft==0.2.0
|
peft==0.2.0
|
||||||
requests
|
requests
|
||||||
rwkv==0.7.2
|
rwkv==0.7.3
|
||||||
safetensors==0.3.0
|
safetensors==0.3.0
|
||||||
sentencepiece
|
sentencepiece
|
||||||
pyyaml
|
pyyaml
|
||||||
tqdm
|
tqdm
|
||||||
git+https://github.com/huggingface/transformers@9eae4aa57650c1dbe1becd4e0979f6ad1e572ac0
|
git+https://github.com/huggingface/transformers
|
||||||
|
bitsandbytes==0.37.2; platform_system != "Windows"
|
||||||
|
llama-cpp-python==0.1.30; platform_system != "Windows"
|
||||||
|
https://github.com/abetlen/llama-cpp-python/releases/download/v0.1.30/llama_cpp_python-0.1.30-cp310-cp310-win_amd64.whl; platform_system == "Windows"
|
||||||
|
|||||||
376
server.py
376
server.py
@@ -2,11 +2,14 @@ import os
|
|||||||
|
|
||||||
os.environ['GRADIO_ANALYTICS_ENABLED'] = 'False'
|
os.environ['GRADIO_ANALYTICS_ENABLED'] = 'False'
|
||||||
|
|
||||||
|
import importlib
|
||||||
import io
|
import io
|
||||||
import json
|
import json
|
||||||
|
import os
|
||||||
import re
|
import re
|
||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
|
import traceback
|
||||||
import zipfile
|
import zipfile
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@@ -15,12 +18,12 @@ import gradio as gr
|
|||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
import modules.extensions as extensions_module
|
import modules.extensions as extensions_module
|
||||||
from modules import chat, shared, training, ui, api
|
from modules import api, chat, shared, training, ui
|
||||||
from modules.html_generator import chat_html_wrapper
|
from modules.html_generator import chat_html_wrapper
|
||||||
from modules.LoRA import add_lora_to_model
|
from modules.LoRA import add_lora_to_model
|
||||||
from modules.models import load_model, load_soft_prompt
|
from modules.models import load_model, load_soft_prompt, unload_model
|
||||||
from modules.text_generation import (clear_torch_cache, generate_reply,
|
from modules.text_generation import generate_reply, stop_everything_event
|
||||||
stop_everything_event)
|
|
||||||
|
|
||||||
# Loading custom settings
|
# Loading custom settings
|
||||||
settings_file = None
|
settings_file = None
|
||||||
@@ -34,15 +37,18 @@ if settings_file is not None:
|
|||||||
for item in new_settings:
|
for item in new_settings:
|
||||||
shared.settings[item] = new_settings[item]
|
shared.settings[item] = new_settings[item]
|
||||||
|
|
||||||
|
|
||||||
def get_available_models():
|
def get_available_models():
|
||||||
if shared.args.flexgen:
|
if shared.args.flexgen:
|
||||||
return sorted([re.sub('-np$', '', item.name) for item in list(Path(f'{shared.args.model_dir}/').glob('*')) if item.name.endswith('-np')], key=str.lower)
|
return sorted([re.sub('-np$', '', item.name) for item in list(Path(f'{shared.args.model_dir}/').glob('*')) if item.name.endswith('-np')], key=str.lower)
|
||||||
else:
|
else:
|
||||||
return sorted([re.sub('.pth$', '', item.name) for item in list(Path(f'{shared.args.model_dir}/').glob('*')) if not item.name.endswith(('.txt', '-np', '.pt', '.json'))], key=str.lower)
|
return sorted([re.sub('.pth$', '', item.name) for item in list(Path(f'{shared.args.model_dir}/').glob('*')) if not item.name.endswith(('.txt', '-np', '.pt', '.json'))], key=str.lower)
|
||||||
|
|
||||||
|
|
||||||
def get_available_presets():
|
def get_available_presets():
|
||||||
return sorted(set((k.stem for k in Path('presets').glob('*.txt'))), key=str.lower)
|
return sorted(set((k.stem for k in Path('presets').glob('*.txt'))), key=str.lower)
|
||||||
|
|
||||||
|
|
||||||
def get_available_prompts():
|
def get_available_prompts():
|
||||||
prompts = []
|
prompts = []
|
||||||
prompts += sorted(set((k.stem for k in Path('prompts').glob('[0-9]*.txt'))), key=str.lower, reverse=True)
|
prompts += sorted(set((k.stem for k in Path('prompts').glob('[0-9]*.txt'))), key=str.lower, reverse=True)
|
||||||
@@ -50,26 +56,31 @@ def get_available_prompts():
|
|||||||
prompts += ['None']
|
prompts += ['None']
|
||||||
return prompts
|
return prompts
|
||||||
|
|
||||||
|
|
||||||
def get_available_characters():
|
def get_available_characters():
|
||||||
paths = (x for x in Path('characters').iterdir() if x.suffix in ('.json', '.yaml', '.yml'))
|
paths = (x for x in Path('characters').iterdir() if x.suffix in ('.json', '.yaml', '.yml'))
|
||||||
return ['None'] + sorted(set((k.stem for k in paths if k.stem != "instruction-following")), key=str.lower)
|
return ['None'] + sorted(set((k.stem for k in paths if k.stem != "instruction-following")), key=str.lower)
|
||||||
|
|
||||||
|
|
||||||
def get_available_instruction_templates():
|
def get_available_instruction_templates():
|
||||||
paths = (x for x in Path('characters/instruction-following').iterdir() if x.suffix in ('.json', '.yaml', '.yml'))
|
path = "characters/instruction-following"
|
||||||
|
paths = []
|
||||||
|
if os.path.exists(path):
|
||||||
|
paths = (x for x in Path(path).iterdir() if x.suffix in ('.json', '.yaml', '.yml'))
|
||||||
return ['None'] + sorted(set((k.stem for k in paths)), key=str.lower)
|
return ['None'] + sorted(set((k.stem for k in paths)), key=str.lower)
|
||||||
|
|
||||||
|
|
||||||
def get_available_extensions():
|
def get_available_extensions():
|
||||||
return sorted(set(map(lambda x: x.parts[1], Path('extensions').glob('*/script.py'))), key=str.lower)
|
return sorted(set(map(lambda x: x.parts[1], Path('extensions').glob('*/script.py'))), key=str.lower)
|
||||||
|
|
||||||
|
|
||||||
def get_available_softprompts():
|
def get_available_softprompts():
|
||||||
return ['None'] + sorted(set((k.stem for k in Path('softprompts').glob('*.zip'))), key=str.lower)
|
return ['None'] + sorted(set((k.stem for k in Path('softprompts').glob('*.zip'))), key=str.lower)
|
||||||
|
|
||||||
|
|
||||||
def get_available_loras():
|
def get_available_loras():
|
||||||
return ['None'] + sorted([item.name for item in list(Path(shared.args.lora_dir).glob('*')) if not item.name.endswith(('.txt', '-np', '.pt', '.json'))], key=str.lower)
|
return ['None'] + sorted([item.name for item in list(Path(shared.args.lora_dir).glob('*')) if not item.name.endswith(('.txt', '-np', '.pt', '.json'))], key=str.lower)
|
||||||
|
|
||||||
def unload_model():
|
|
||||||
shared.model = shared.tokenizer = None
|
|
||||||
clear_torch_cache()
|
|
||||||
|
|
||||||
def load_model_wrapper(selected_model):
|
def load_model_wrapper(selected_model):
|
||||||
if selected_model != shared.model_name:
|
if selected_model != shared.model_name:
|
||||||
@@ -81,10 +92,12 @@ def load_model_wrapper(selected_model):
|
|||||||
|
|
||||||
return selected_model
|
return selected_model
|
||||||
|
|
||||||
|
|
||||||
def load_lora_wrapper(selected_lora):
|
def load_lora_wrapper(selected_lora):
|
||||||
add_lora_to_model(selected_lora)
|
add_lora_to_model(selected_lora)
|
||||||
return selected_lora
|
return selected_lora
|
||||||
|
|
||||||
|
|
||||||
def load_preset_values(preset_menu, state, return_dict=False):
|
def load_preset_values(preset_menu, state, return_dict=False):
|
||||||
generate_params = {
|
generate_params = {
|
||||||
'do_sample': True,
|
'do_sample': True,
|
||||||
@@ -115,6 +128,7 @@ def load_preset_values(preset_menu, state, return_dict=False):
|
|||||||
state.update(generate_params)
|
state.update(generate_params)
|
||||||
return state, *[generate_params[k] for k in ['do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']]
|
return state, *[generate_params[k] for k in ['do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']]
|
||||||
|
|
||||||
|
|
||||||
def upload_soft_prompt(file):
|
def upload_soft_prompt(file):
|
||||||
with zipfile.ZipFile(io.BytesIO(file)) as zf:
|
with zipfile.ZipFile(io.BytesIO(file)) as zf:
|
||||||
zf.extract('meta.json')
|
zf.extract('meta.json')
|
||||||
@@ -127,16 +141,6 @@ def upload_soft_prompt(file):
|
|||||||
|
|
||||||
return name
|
return name
|
||||||
|
|
||||||
def create_model_and_preset_menus():
|
|
||||||
with gr.Row():
|
|
||||||
with gr.Column():
|
|
||||||
with gr.Row():
|
|
||||||
shared.gradio['model_menu'] = gr.Dropdown(choices=available_models, value=shared.model_name, label='Model')
|
|
||||||
ui.create_refresh_button(shared.gradio['model_menu'], lambda : None, lambda : {'choices': get_available_models()}, 'refresh-button')
|
|
||||||
with gr.Column():
|
|
||||||
with gr.Row():
|
|
||||||
shared.gradio['preset_menu'] = gr.Dropdown(choices=available_presets, value=default_preset if not shared.args.flexgen else 'Naive', label='Generation parameters preset')
|
|
||||||
ui.create_refresh_button(shared.gradio['preset_menu'], lambda : None, lambda : {'choices': get_available_presets()}, 'refresh-button')
|
|
||||||
|
|
||||||
def save_prompt(text):
|
def save_prompt(text):
|
||||||
fname = f"{datetime.now().strftime('%Y-%m-%d-%H%M%S')}.txt"
|
fname = f"{datetime.now().strftime('%Y-%m-%d-%H%M%S')}.txt"
|
||||||
@@ -144,6 +148,7 @@ def save_prompt(text):
|
|||||||
f.write(text)
|
f.write(text)
|
||||||
return f"Saved to prompts/{fname}"
|
return f"Saved to prompts/{fname}"
|
||||||
|
|
||||||
|
|
||||||
def load_prompt(fname):
|
def load_prompt(fname):
|
||||||
if fname in ['None', '']:
|
if fname in ['None', '']:
|
||||||
return ''
|
return ''
|
||||||
@@ -154,6 +159,7 @@ def load_prompt(fname):
|
|||||||
text = text[:-1]
|
text = text[:-1]
|
||||||
return text
|
return text
|
||||||
|
|
||||||
|
|
||||||
def create_prompt_menus():
|
def create_prompt_menus():
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
@@ -169,40 +175,95 @@ def create_prompt_menus():
|
|||||||
shared.gradio['prompt_menu'].change(load_prompt, [shared.gradio['prompt_menu']], [shared.gradio['textbox']], show_progress=False)
|
shared.gradio['prompt_menu'].change(load_prompt, [shared.gradio['prompt_menu']], [shared.gradio['textbox']], show_progress=False)
|
||||||
shared.gradio['save_prompt'].click(save_prompt, [shared.gradio['textbox']], [shared.gradio['status']], show_progress=False)
|
shared.gradio['save_prompt'].click(save_prompt, [shared.gradio['textbox']], [shared.gradio['status']], show_progress=False)
|
||||||
|
|
||||||
|
|
||||||
|
def download_model_wrapper(repo_id):
|
||||||
|
try:
|
||||||
|
downloader = importlib.import_module("download-model")
|
||||||
|
|
||||||
|
model = repo_id
|
||||||
|
branch = "main"
|
||||||
|
check = False
|
||||||
|
|
||||||
|
yield ("Cleaning up the model/branch names")
|
||||||
|
model, branch = downloader.sanitize_model_and_branch_names(model, branch)
|
||||||
|
|
||||||
|
yield ("Getting the download links from Hugging Face")
|
||||||
|
links, sha256, is_lora = downloader.get_download_links_from_huggingface(model, branch, text_only=False)
|
||||||
|
|
||||||
|
yield ("Getting the output folder")
|
||||||
|
output_folder = downloader.get_output_folder(model, branch, is_lora)
|
||||||
|
|
||||||
|
if check:
|
||||||
|
yield ("Checking previously downloaded files")
|
||||||
|
downloader.check_model_files(model, branch, links, sha256, output_folder)
|
||||||
|
else:
|
||||||
|
yield (f"Downloading files to {output_folder}")
|
||||||
|
downloader.download_model_files(model, branch, links, sha256, output_folder, threads=1)
|
||||||
|
yield ("Done!")
|
||||||
|
except:
|
||||||
|
yield traceback.format_exc()
|
||||||
|
|
||||||
|
|
||||||
|
def create_model_menus():
|
||||||
|
with gr.Row():
|
||||||
|
with gr.Column():
|
||||||
|
with gr.Row():
|
||||||
|
shared.gradio['model_menu'] = gr.Dropdown(choices=available_models, value=shared.model_name, label='Model')
|
||||||
|
ui.create_refresh_button(shared.gradio['model_menu'], lambda: None, lambda: {'choices': get_available_models()}, 'refresh-button')
|
||||||
|
with gr.Column():
|
||||||
|
with gr.Row():
|
||||||
|
shared.gradio['lora_menu'] = gr.Dropdown(choices=available_loras, value=shared.lora_name, label='LoRA')
|
||||||
|
ui.create_refresh_button(shared.gradio['lora_menu'], lambda: None, lambda: {'choices': get_available_loras()}, 'refresh-button')
|
||||||
|
with gr.Row():
|
||||||
|
with gr.Column():
|
||||||
|
with gr.Row():
|
||||||
|
with gr.Column():
|
||||||
|
shared.gradio['custom_model_menu'] = gr.Textbox(label="Download custom model or LoRA",
|
||||||
|
info="Enter Hugging Face username/model path, e.g: facebook/galactica-125m")
|
||||||
|
with gr.Column():
|
||||||
|
shared.gradio['download_button'] = gr.Button("Download")
|
||||||
|
shared.gradio['download_status'] = gr.Markdown()
|
||||||
|
with gr.Column():
|
||||||
|
pass
|
||||||
|
|
||||||
|
shared.gradio['model_menu'].change(load_model_wrapper, shared.gradio['model_menu'], shared.gradio['model_menu'], show_progress=True)
|
||||||
|
shared.gradio['lora_menu'].change(load_lora_wrapper, shared.gradio['lora_menu'], shared.gradio['lora_menu'], show_progress=True)
|
||||||
|
shared.gradio['download_button'].click(download_model_wrapper, shared.gradio['custom_model_menu'], shared.gradio['download_status'], show_progress=False)
|
||||||
|
|
||||||
|
|
||||||
def create_settings_menus(default_preset):
|
def create_settings_menus(default_preset):
|
||||||
|
|
||||||
generate_params = load_preset_values(default_preset if not shared.args.flexgen else 'Naive', {}, return_dict=True)
|
generate_params = load_preset_values(default_preset if not shared.args.flexgen else 'Naive', {}, return_dict=True)
|
||||||
for k in ['max_new_tokens', 'seed', 'stop_at_newline', 'chat_prompt_size', 'chat_generation_attempts']:
|
|
||||||
generate_params[k] = shared.settings[k]
|
|
||||||
shared.gradio['generate_state'] = gr.State(generate_params)
|
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
create_model_and_preset_menus()
|
with gr.Row():
|
||||||
|
shared.gradio['preset_menu'] = gr.Dropdown(choices=available_presets, value=default_preset if not shared.args.flexgen else 'Naive', label='Generation parameters preset')
|
||||||
|
ui.create_refresh_button(shared.gradio['preset_menu'], lambda: None, lambda: {'choices': get_available_presets()}, 'refresh-button')
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
shared.gradio['seed'] = gr.Number(value=shared.settings['seed'], label='Seed (-1 for random)')
|
shared.gradio['seed'] = gr.Number(value=shared.settings['seed'], label='Seed (-1 for random)')
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
with gr.Box():
|
with gr.Box():
|
||||||
gr.Markdown('Custom generation parameters ([reference](https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.GenerationConfig))')
|
gr.Markdown('Custom generation parameters ([click here to view technical documentation](https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.GenerationConfig))')
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
shared.gradio['temperature'] = gr.Slider(0.01, 1.99, value=generate_params['temperature'], step=0.01, label='temperature')
|
shared.gradio['temperature'] = gr.Slider(0.01, 1.99, value=generate_params['temperature'], step=0.01, label='temperature', info='Primary factor to control randomness of outputs. 0 = deterministic (only the most likely token is used). Higher value = more randomness.')
|
||||||
shared.gradio['top_p'] = gr.Slider(0.0,1.0,value=generate_params['top_p'],step=0.01,label='top_p')
|
shared.gradio['top_p'] = gr.Slider(0.0, 1.0, value=generate_params['top_p'], step=0.01, label='top_p', info='If not set to 1, select tokens with probabilities adding up to less than this number. Higher value = higher range of possible random results.')
|
||||||
shared.gradio['top_k'] = gr.Slider(0,200,value=generate_params['top_k'],step=1,label='top_k')
|
shared.gradio['top_k'] = gr.Slider(0, 200, value=generate_params['top_k'], step=1, label='top_k', info='Similar to top_p, but select instead only the top_k most likely tokens. Higher value = higher range of possible random results.')
|
||||||
shared.gradio['typical_p'] = gr.Slider(0.0,1.0,value=generate_params['typical_p'],step=0.01,label='typical_p')
|
shared.gradio['typical_p'] = gr.Slider(0.0, 1.0, value=generate_params['typical_p'], step=0.01, label='typical_p', info='If not set to 1, select only tokens that are at least this much more likely to appear than random tokens, given the prior text.')
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
shared.gradio['repetition_penalty'] = gr.Slider(1.0, 1.5, value=generate_params['repetition_penalty'],step=0.01,label='repetition_penalty')
|
shared.gradio['repetition_penalty'] = gr.Slider(1.0, 1.5, value=generate_params['repetition_penalty'], step=0.01, label='repetition_penalty', info='Exponential penalty factor for repeating prior tokens. 1 means no penalty, higher value = less repetition, lower value = more repetition.')
|
||||||
shared.gradio['encoder_repetition_penalty'] = gr.Slider(0.8, 1.5, value=generate_params['encoder_repetition_penalty'],step=0.01,label='encoder_repetition_penalty')
|
shared.gradio['encoder_repetition_penalty'] = gr.Slider(0.8, 1.5, value=generate_params['encoder_repetition_penalty'], step=0.01, label='encoder_repetition_penalty', info='Also known as the "Hallucinations filter". Used to penalize tokens that are *not* in the prior text. Higher value = more likely to stay in context, lower value = more likely to diverge.')
|
||||||
shared.gradio['no_repeat_ngram_size'] = gr.Slider(0, 20, step=1, value=generate_params['no_repeat_ngram_size'], label='no_repeat_ngram_size')
|
shared.gradio['no_repeat_ngram_size'] = gr.Slider(0, 20, step=1, value=generate_params['no_repeat_ngram_size'], label='no_repeat_ngram_size', info='If not set to 0, specifies the length of token sets that are completely blocked from repeating at all. Higher values = blocks larger phrases, lower values = blocks words or letters from repeating. Only 0 or high values are a good idea in most cases.')
|
||||||
shared.gradio['min_length'] = gr.Slider(0, 2000, step=1, value=generate_params['min_length'] if shared.args.no_stream else 0, label='min_length', interactive=shared.args.no_stream)
|
shared.gradio['min_length'] = gr.Slider(0, 2000, step=1, value=generate_params['min_length'], label='min_length', info='Minimum generation length in tokens.')
|
||||||
shared.gradio['do_sample'] = gr.Checkbox(value=generate_params['do_sample'], label='do_sample')
|
shared.gradio['do_sample'] = gr.Checkbox(value=generate_params['do_sample'], label='do_sample')
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
with gr.Box():
|
with gr.Box():
|
||||||
gr.Markdown('Contrastive search')
|
gr.Markdown('Contrastive search')
|
||||||
shared.gradio['penalty_alpha'] = gr.Slider(0, 5, value=generate_params['penalty_alpha'], label='penalty_alpha')
|
shared.gradio['penalty_alpha'] = gr.Slider(0, 5, value=generate_params['penalty_alpha'], label='penalty_alpha')
|
||||||
|
|
||||||
with gr.Box():
|
|
||||||
gr.Markdown('Beam search (uses a lot of VRAM)')
|
gr.Markdown('Beam search (uses a lot of VRAM)')
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
@@ -211,9 +272,12 @@ def create_settings_menus(default_preset):
|
|||||||
shared.gradio['length_penalty'] = gr.Slider(-5, 5, value=generate_params['length_penalty'], label='length_penalty')
|
shared.gradio['length_penalty'] = gr.Slider(-5, 5, value=generate_params['length_penalty'], label='length_penalty')
|
||||||
shared.gradio['early_stopping'] = gr.Checkbox(value=generate_params['early_stopping'], label='early_stopping')
|
shared.gradio['early_stopping'] = gr.Checkbox(value=generate_params['early_stopping'], label='early_stopping')
|
||||||
|
|
||||||
|
with gr.Group():
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
shared.gradio['lora_menu'] = gr.Dropdown(choices=available_loras, value=shared.lora_name, label='LoRA')
|
shared.gradio['add_bos_token'] = gr.Checkbox(value=shared.settings['add_bos_token'], label='Add the bos_token to the beginning of prompts', info='Disabling this can make the replies more creative.')
|
||||||
ui.create_refresh_button(shared.gradio['lora_menu'], lambda : None, lambda : {'choices': get_available_loras()}, 'refresh-button')
|
shared.gradio['ban_eos_token'] = gr.Checkbox(value=shared.settings['ban_eos_token'], label='Ban the eos_token', info='This forces the model to never end the generation prematurely.')
|
||||||
|
shared.gradio['truncation_length'] = gr.Slider(value=shared.settings['truncation_length'], minimum=shared.settings['truncation_length_min'], maximum=shared.settings['truncation_length_max'], step=1, label='Truncate the prompt up to this length', info='The leftmost tokens are removed if the prompt exceeds this length. Most models require this to be at most 2048.')
|
||||||
|
shared.gradio['custom_stopping_strings'] = gr.Textbox(lines=1, value=shared.settings["custom_stopping_strings"] or None, label='Custom stopping strings', info='In addition to the defaults. Written between "" and separated by commas. For instance: "\\nYour Assistant:", "\\nThe assistant:"')
|
||||||
|
|
||||||
with gr.Accordion('Soft prompt', open=False):
|
with gr.Accordion('Soft prompt', open=False):
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
@@ -224,12 +288,11 @@ def create_settings_menus(default_preset):
|
|||||||
with gr.Row():
|
with gr.Row():
|
||||||
shared.gradio['upload_softprompt'] = gr.File(type='binary', file_types=['.zip'])
|
shared.gradio['upload_softprompt'] = gr.File(type='binary', file_types=['.zip'])
|
||||||
|
|
||||||
shared.gradio['model_menu'].change(load_model_wrapper, shared.gradio['model_menu'], shared.gradio['model_menu'], show_progress=True)
|
shared.gradio['preset_menu'].change(load_preset_values, [shared.gradio[k] for k in ['preset_menu', 'interface_state']], [shared.gradio[k] for k in ['interface_state', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']])
|
||||||
shared.gradio['preset_menu'].change(load_preset_values, [shared.gradio[k] for k in ['preset_menu', 'generate_state']], [shared.gradio[k] for k in ['generate_state', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']])
|
|
||||||
shared.gradio['lora_menu'].change(load_lora_wrapper, shared.gradio['lora_menu'], shared.gradio['lora_menu'], show_progress=True)
|
|
||||||
shared.gradio['softprompts_menu'].change(load_soft_prompt, shared.gradio['softprompts_menu'], shared.gradio['softprompts_menu'], show_progress=True)
|
shared.gradio['softprompts_menu'].change(load_soft_prompt, shared.gradio['softprompts_menu'], shared.gradio['softprompts_menu'], show_progress=True)
|
||||||
shared.gradio['upload_softprompt'].upload(upload_soft_prompt, shared.gradio['upload_softprompt'], shared.gradio['softprompts_menu'])
|
shared.gradio['upload_softprompt'].upload(upload_soft_prompt, shared.gradio['upload_softprompt'], shared.gradio['softprompts_menu'])
|
||||||
|
|
||||||
|
|
||||||
def set_interface_arguments(interface_mode, extensions, bool_active):
|
def set_interface_arguments(interface_mode, extensions, bool_active):
|
||||||
modes = ["default", "notebook", "chat", "cai_chat"]
|
modes = ["default", "notebook", "chat", "cai_chat"]
|
||||||
cmd_list = vars(shared.args)
|
cmd_list = vars(shared.args)
|
||||||
@@ -248,6 +311,7 @@ def set_interface_arguments(interface_mode, extensions, bool_active):
|
|||||||
|
|
||||||
shared.need_restart = True
|
shared.need_restart = True
|
||||||
|
|
||||||
|
|
||||||
available_models = get_available_models()
|
available_models = get_available_models()
|
||||||
available_presets = get_available_presets()
|
available_presets = get_available_presets()
|
||||||
available_characters = get_available_characters()
|
available_characters = get_available_characters()
|
||||||
@@ -296,35 +360,57 @@ else:
|
|||||||
default_text = load_prompt(shared.settings['prompts'][next((k for k in shared.settings['prompts'] if re.match(k.lower(), shared.model_name.lower())), 'default')])
|
default_text = load_prompt(shared.settings['prompts'][next((k for k in shared.settings['prompts'] if re.match(k.lower(), shared.model_name.lower())), 'default')])
|
||||||
title = 'Text generation web UI'
|
title = 'Text generation web UI'
|
||||||
|
|
||||||
def create_interface():
|
|
||||||
|
|
||||||
|
def list_interface_input_elements(chat=False):
|
||||||
|
elements = ['max_new_tokens', 'seed', 'temperature', 'top_p', 'top_k', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'no_repeat_ngram_size', 'min_length', 'do_sample', 'penalty_alpha', 'num_beams', 'length_penalty', 'early_stopping', 'add_bos_token', 'ban_eos_token', 'truncation_length', 'custom_stopping_strings']
|
||||||
|
if chat:
|
||||||
|
elements += ['name1', 'name2', 'greeting', 'context', 'end_of_turn', 'chat_prompt_size', 'chat_generation_attempts', 'stop_at_newline', 'mode']
|
||||||
|
return elements
|
||||||
|
|
||||||
|
|
||||||
|
def gather_interface_values(*args):
|
||||||
|
output = {}
|
||||||
|
for i, element in enumerate(shared.input_elements):
|
||||||
|
output[element] = args[i]
|
||||||
|
output['custom_stopping_strings'] = eval(f"[{output['custom_stopping_strings']}]")
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
def create_interface():
|
||||||
gen_events = []
|
gen_events = []
|
||||||
if shared.args.extensions is not None and len(shared.args.extensions) > 0:
|
if shared.args.extensions is not None and len(shared.args.extensions) > 0:
|
||||||
extensions_module.load_extensions()
|
extensions_module.load_extensions()
|
||||||
|
|
||||||
with gr.Blocks(css=ui.css if not shared.is_chat() else ui.css + ui.chat_css, analytics_enabled=False, title=title) as shared.gradio['interface']:
|
with gr.Blocks(css=ui.css if not shared.is_chat() else ui.css + ui.chat_css, analytics_enabled=False, title=title) as shared.gradio['interface']:
|
||||||
if shared.is_chat():
|
if shared.is_chat():
|
||||||
|
|
||||||
|
shared.input_elements = list_interface_input_elements(chat=True)
|
||||||
|
shared.gradio['interface_state'] = gr.State({k: None for k in shared.input_elements})
|
||||||
shared.gradio['Chat input'] = gr.State()
|
shared.gradio['Chat input'] = gr.State()
|
||||||
|
|
||||||
with gr.Tab("Text generation", elem_id="main"):
|
with gr.Tab("Text generation", elem_id="main"):
|
||||||
shared.gradio['display'] = gr.HTML(value=chat_html_wrapper(shared.history['visible'], shared.settings['name1'], shared.settings['name2'], 'cai-chat'))
|
shared.gradio['display'] = gr.HTML(value=chat_html_wrapper(shared.history['visible'], shared.settings['name1'], shared.settings['name2'], 'cai-chat'))
|
||||||
shared.gradio['textbox'] = gr.Textbox(label='Input')
|
shared.gradio['textbox'] = gr.Textbox(label='Input')
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
shared.gradio['Generate'] = gr.Button('Generate')
|
shared.gradio['Generate'] = gr.Button('Generate', elem_id='Generate')
|
||||||
shared.gradio['Stop'] = gr.Button('Stop', elem_id="stop")
|
shared.gradio['Stop'] = gr.Button('Stop', elem_id="stop")
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
shared.gradio['Impersonate'] = gr.Button('Impersonate')
|
|
||||||
shared.gradio['Regenerate'] = gr.Button('Regenerate')
|
shared.gradio['Regenerate'] = gr.Button('Regenerate')
|
||||||
|
shared.gradio['Continue'] = gr.Button('Continue')
|
||||||
|
shared.gradio['Impersonate'] = gr.Button('Impersonate')
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
shared.gradio['Copy last reply'] = gr.Button('Copy last reply')
|
shared.gradio['Send dummy message'] = gr.Button('Send dummy message')
|
||||||
|
shared.gradio['Send dummy reply'] = gr.Button('Send dummy reply')
|
||||||
shared.gradio['Replace last reply'] = gr.Button('Replace last reply')
|
shared.gradio['Replace last reply'] = gr.Button('Replace last reply')
|
||||||
shared.gradio['Remove last'] = gr.Button('Remove last')
|
shared.gradio['Copy last reply'] = gr.Button('Copy last reply')
|
||||||
|
with gr.Row():
|
||||||
shared.gradio['Clear history'] = gr.Button('Clear history')
|
shared.gradio['Clear history'] = gr.Button('Clear history')
|
||||||
shared.gradio['Clear history-confirm'] = gr.Button('Confirm', variant="stop", visible=False)
|
shared.gradio['Clear history-confirm'] = gr.Button('Confirm', variant="stop", visible=False)
|
||||||
shared.gradio['Clear history-cancel'] = gr.Button('Cancel', visible=False)
|
shared.gradio['Clear history-cancel'] = gr.Button('Cancel', visible=False)
|
||||||
|
shared.gradio['Remove last'] = gr.Button('Remove last')
|
||||||
|
|
||||||
shared.gradio["Chat mode"] = gr.Radio(choices=["cai-chat", "chat", "instruct"], value="cai-chat", label="Mode")
|
shared.gradio["mode"] = gr.Radio(choices=["cai-chat", "chat", "instruct"], value="cai-chat", label="Mode")
|
||||||
shared.gradio["Instruction templates"] = gr.Dropdown(choices=get_available_instruction_templates(), label="Instruction template", value="None", visible=False)
|
shared.gradio["Instruction templates"] = gr.Dropdown(choices=get_available_instruction_templates(), label="Instruction template", value="None", visible=False, info="Change this according to the model/LoRA that you are using.")
|
||||||
|
|
||||||
with gr.Tab("Character", elem_id="chat-settings"):
|
with gr.Tab("Character", elem_id="chat-settings"):
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
@@ -371,66 +457,102 @@ def create_interface():
|
|||||||
with gr.Row():
|
with gr.Row():
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
shared.gradio['max_new_tokens'] = gr.Slider(minimum=shared.settings['max_new_tokens_min'], maximum=shared.settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=shared.settings['max_new_tokens'])
|
shared.gradio['max_new_tokens'] = gr.Slider(minimum=shared.settings['max_new_tokens_min'], maximum=shared.settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=shared.settings['max_new_tokens'])
|
||||||
shared.gradio['chat_prompt_size_slider'] = gr.Slider(minimum=shared.settings['chat_prompt_size_min'], maximum=shared.settings['chat_prompt_size_max'], step=1, label='Maximum prompt size in tokens', value=shared.settings['chat_prompt_size'])
|
shared.gradio['chat_prompt_size'] = gr.Slider(minimum=shared.settings['chat_prompt_size_min'], maximum=shared.settings['chat_prompt_size_max'], step=1, label='Maximum prompt size in tokens', value=shared.settings['chat_prompt_size'])
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
shared.gradio['chat_generation_attempts'] = gr.Slider(minimum=shared.settings['chat_generation_attempts_min'], maximum=shared.settings['chat_generation_attempts_max'], value=shared.settings['chat_generation_attempts'], step=1, label='Generation attempts (for longer replies)')
|
shared.gradio['chat_generation_attempts'] = gr.Slider(minimum=shared.settings['chat_generation_attempts_min'], maximum=shared.settings['chat_generation_attempts_max'], value=shared.settings['chat_generation_attempts'], step=1, label='Generation attempts (for longer replies)')
|
||||||
shared.gradio['stop_at_newline'] = gr.Checkbox(value=shared.settings['stop_at_newline'], label='Stop generating at new line character?')
|
shared.gradio['stop_at_newline'] = gr.Checkbox(value=shared.settings['stop_at_newline'], label='Stop generating at new line character')
|
||||||
|
|
||||||
create_settings_menus(default_preset)
|
create_settings_menus(default_preset)
|
||||||
|
|
||||||
shared.input_params = [shared.gradio[k] for k in ['Chat input', 'generate_state', 'name1', 'name2', 'context', 'Chat mode', 'end_of_turn']]
|
shared.input_params = [shared.gradio[k] for k in ['Chat input', 'interface_state']]
|
||||||
|
|
||||||
def set_chat_input(textbox):
|
|
||||||
return textbox, ""
|
|
||||||
|
|
||||||
gen_events.append(shared.gradio['Generate'].click(set_chat_input, shared.gradio['textbox'], [shared.gradio['Chat input'], shared.gradio['textbox']], show_progress=False))
|
|
||||||
gen_events.append(shared.gradio['Generate'].click(chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream))
|
|
||||||
gen_events.append(shared.gradio['textbox'].submit(set_chat_input, shared.gradio['textbox'], [shared.gradio['Chat input'], shared.gradio['textbox']], show_progress=False))
|
|
||||||
gen_events.append(shared.gradio['textbox'].submit(chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream))
|
|
||||||
gen_events.append(shared.gradio['Regenerate'].click(chat.regenerate_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream))
|
|
||||||
gen_events.append(shared.gradio['Impersonate'].click(chat.impersonate_wrapper, shared.input_params, shared.gradio['textbox'], show_progress=shared.args.no_stream))
|
|
||||||
shared.gradio['Stop'].click(stop_everything_event, [], [], queue=False, cancels=gen_events if shared.args.no_stream else None)
|
|
||||||
|
|
||||||
shared.gradio['Copy last reply'].click(chat.send_last_reply_to_input, [], shared.gradio['textbox'], show_progress=shared.args.no_stream)
|
|
||||||
shared.gradio['Replace last reply'].click(chat.replace_last_reply, [shared.gradio[k] for k in ['textbox', 'name1', 'name2', 'Chat mode']], shared.gradio['display'], show_progress=shared.args.no_stream)
|
|
||||||
|
|
||||||
# Clear history with confirmation
|
|
||||||
clear_arr = [shared.gradio[k] for k in ['Clear history-confirm', 'Clear history', 'Clear history-cancel']]
|
clear_arr = [shared.gradio[k] for k in ['Clear history-confirm', 'Clear history', 'Clear history-cancel']]
|
||||||
|
reload_inputs = [shared.gradio[k] for k in ['name1', 'name2', 'mode']]
|
||||||
|
|
||||||
|
gen_events.append(shared.gradio['Generate'].click(
|
||||||
|
gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
|
||||||
|
lambda x: (x, ''), shared.gradio['textbox'], [shared.gradio['Chat input'], shared.gradio['textbox']], show_progress=False).then(
|
||||||
|
chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream).then(
|
||||||
|
chat.save_history, shared.gradio['mode'], None, show_progress=False)
|
||||||
|
)
|
||||||
|
|
||||||
|
gen_events.append(shared.gradio['textbox'].submit(
|
||||||
|
gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
|
||||||
|
lambda x: (x, ''), shared.gradio['textbox'], [shared.gradio['Chat input'], shared.gradio['textbox']], show_progress=False).then(
|
||||||
|
chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream).then(
|
||||||
|
chat.save_history, shared.gradio['mode'], None, show_progress=False)
|
||||||
|
)
|
||||||
|
|
||||||
|
gen_events.append(shared.gradio['Regenerate'].click(
|
||||||
|
gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
|
||||||
|
chat.regenerate_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream).then(
|
||||||
|
chat.save_history, shared.gradio['mode'], None, show_progress=False)
|
||||||
|
)
|
||||||
|
|
||||||
|
gen_events.append(shared.gradio['Continue'].click(
|
||||||
|
gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
|
||||||
|
chat.continue_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream).then(
|
||||||
|
chat.save_history, shared.gradio['mode'], None, show_progress=False)
|
||||||
|
)
|
||||||
|
|
||||||
|
gen_events.append(shared.gradio['Impersonate'].click(
|
||||||
|
gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
|
||||||
|
chat.impersonate_wrapper, shared.input_params, shared.gradio['textbox'], show_progress=shared.args.no_stream)
|
||||||
|
)
|
||||||
|
|
||||||
|
shared.gradio['Replace last reply'].click(
|
||||||
|
chat.replace_last_reply, [shared.gradio[k] for k in ['textbox', 'name1', 'name2', 'mode']], shared.gradio['display'], show_progress=shared.args.no_stream).then(
|
||||||
|
lambda x: '', shared.gradio['textbox'], shared.gradio['textbox'], show_progress=False).then(
|
||||||
|
chat.save_history, shared.gradio['mode'], None, show_progress=False)
|
||||||
|
|
||||||
|
shared.gradio['Send dummy message'].click(
|
||||||
|
chat.send_dummy_message, [shared.gradio[k] for k in ['textbox', 'name1', 'name2', 'mode']], shared.gradio['display'], show_progress=shared.args.no_stream).then(
|
||||||
|
lambda x: '', shared.gradio['textbox'], shared.gradio['textbox'], show_progress=False).then(
|
||||||
|
chat.save_history, shared.gradio['mode'], None, show_progress=False)
|
||||||
|
|
||||||
|
shared.gradio['Send dummy reply'].click(
|
||||||
|
chat.send_dummy_reply, [shared.gradio[k] for k in ['textbox', 'name1', 'name2', 'mode']], shared.gradio['display'], show_progress=shared.args.no_stream).then(
|
||||||
|
lambda x: '', shared.gradio['textbox'], shared.gradio['textbox'], show_progress=False).then(
|
||||||
|
chat.save_history, shared.gradio['mode'], None, show_progress=False)
|
||||||
|
|
||||||
|
shared.gradio['Clear history-confirm'].click(
|
||||||
|
lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, clear_arr).then(
|
||||||
|
chat.clear_chat_log, [shared.gradio[k] for k in ['name1', 'name2', 'greeting', 'mode']], shared.gradio['display']).then(
|
||||||
|
chat.save_history, shared.gradio['mode'], None, show_progress=False)
|
||||||
|
|
||||||
|
shared.gradio['Stop'].click(
|
||||||
|
stop_everything_event, None, None, queue=False, cancels=gen_events if shared.args.no_stream else None).then(
|
||||||
|
chat.redraw_html, reload_inputs, shared.gradio['display'])
|
||||||
|
|
||||||
|
shared.gradio['mode'].change(
|
||||||
|
lambda x: gr.update(visible=x == 'instruct'), shared.gradio['mode'], shared.gradio['Instruction templates']).then(
|
||||||
|
lambda x: gr.update(interactive=x != 'instruct'), shared.gradio['mode'], shared.gradio['character_menu']).then(
|
||||||
|
chat.redraw_html, reload_inputs, shared.gradio['display'])
|
||||||
|
|
||||||
|
shared.gradio['Instruction templates'].change(
|
||||||
|
lambda character, name1, name2, mode: chat.load_character(character, name1, name2, mode), [shared.gradio[k] for k in ['Instruction templates', 'name1', 'name2', 'mode']], [shared.gradio[k] for k in ['name1', 'name2', 'character_picture', 'greeting', 'context', 'end_of_turn', 'display']]).then(
|
||||||
|
chat.redraw_html, reload_inputs, shared.gradio['display'])
|
||||||
|
|
||||||
|
shared.gradio['upload_chat_history'].upload(
|
||||||
|
chat.load_history, [shared.gradio[k] for k in ['upload_chat_history', 'name1', 'name2']], None).then(
|
||||||
|
chat.redraw_html, reload_inputs, shared.gradio['display'])
|
||||||
|
|
||||||
|
shared.gradio['Copy last reply'].click(chat.send_last_reply_to_input, None, shared.gradio['textbox'], show_progress=shared.args.no_stream)
|
||||||
shared.gradio['Clear history'].click(lambda: [gr.update(visible=True), gr.update(visible=False), gr.update(visible=True)], None, clear_arr)
|
shared.gradio['Clear history'].click(lambda: [gr.update(visible=True), gr.update(visible=False), gr.update(visible=True)], None, clear_arr)
|
||||||
shared.gradio['Clear history-confirm'].click(lambda :[gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, clear_arr)
|
|
||||||
shared.gradio['Clear history-confirm'].click(chat.clear_chat_log, [shared.gradio[k] for k in ['name1', 'name2', 'greeting', 'Chat mode']], shared.gradio['display'])
|
|
||||||
shared.gradio['Clear history-cancel'].click(lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, clear_arr)
|
shared.gradio['Clear history-cancel'].click(lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, clear_arr)
|
||||||
shared.gradio['Chat mode'].change(lambda x : gr.update(visible= x=='instruct'), shared.gradio['Chat mode'], shared.gradio['Instruction templates'])
|
shared.gradio['Remove last'].click(chat.remove_last_message, [shared.gradio[k] for k in ['name1', 'name2', 'mode']], [shared.gradio['display'], shared.gradio['textbox']], show_progress=False)
|
||||||
|
shared.gradio['download_button'].click(lambda x: chat.save_history(x, timestamp=True), shared.gradio['mode'], shared.gradio['download'])
|
||||||
shared.gradio['Remove last'].click(chat.remove_last_message, [shared.gradio[k] for k in ['name1', 'name2', 'Chat mode']], [shared.gradio['display'], shared.gradio['textbox']], show_progress=False)
|
|
||||||
shared.gradio['download_button'].click(chat.save_history, inputs=[], outputs=[shared.gradio['download']])
|
|
||||||
shared.gradio['Upload character'].click(chat.upload_character, [shared.gradio['upload_json'], shared.gradio['upload_img_bot']], [shared.gradio['character_menu']])
|
shared.gradio['Upload character'].click(chat.upload_character, [shared.gradio['upload_json'], shared.gradio['upload_img_bot']], [shared.gradio['character_menu']])
|
||||||
|
shared.gradio['character_menu'].change(chat.load_character, [shared.gradio[k] for k in ['character_menu', 'name1', 'name2', 'mode']], [shared.gradio[k] for k in ['name1', 'name2', 'character_picture', 'greeting', 'context', 'end_of_turn', 'display']])
|
||||||
# Clearing stuff and saving the history
|
|
||||||
for i in ['Generate', 'Regenerate', 'Replace last reply']:
|
|
||||||
shared.gradio[i].click(lambda x: '', shared.gradio['textbox'], shared.gradio['textbox'], show_progress=False)
|
|
||||||
shared.gradio[i].click(lambda : chat.save_history(timestamp=False), [], [], show_progress=False)
|
|
||||||
shared.gradio['Clear history-confirm'].click(lambda : chat.save_history(timestamp=False), [], [], show_progress=False)
|
|
||||||
shared.gradio['textbox'].submit(lambda x: '', shared.gradio['textbox'], shared.gradio['textbox'], show_progress=False)
|
|
||||||
shared.gradio['textbox'].submit(lambda : chat.save_history(timestamp=False), [], [], show_progress=False)
|
|
||||||
|
|
||||||
shared.gradio['character_menu'].change(chat.load_character, [shared.gradio[k] for k in ['character_menu', 'name1', 'name2', 'Chat mode']], [shared.gradio[k] for k in ['name1', 'name2', 'character_picture', 'greeting', 'context', 'end_of_turn', 'display']])
|
|
||||||
shared.gradio['Instruction templates'].change(lambda character, name1, name2, mode: chat.load_character(character, name1, name2, mode), [shared.gradio[k] for k in ['Instruction templates', 'name1', 'name2', 'Chat mode']], [shared.gradio[k] for k in ['name1', 'name2', 'character_picture', 'greeting', 'context', 'end_of_turn', 'display']])
|
|
||||||
shared.gradio['upload_chat_history'].upload(chat.load_history, [shared.gradio[k] for k in ['upload_chat_history', 'name1', 'name2']], [])
|
|
||||||
shared.gradio['upload_img_tavern'].upload(chat.upload_tavern_character, [shared.gradio['upload_img_tavern'], shared.gradio['name1'], shared.gradio['name2']], [shared.gradio['character_menu']])
|
shared.gradio['upload_img_tavern'].upload(chat.upload_tavern_character, [shared.gradio['upload_img_tavern'], shared.gradio['name1'], shared.gradio['name2']], [shared.gradio['character_menu']])
|
||||||
shared.gradio['your_picture'].change(chat.upload_your_profile_picture, [shared.gradio[k] for k in ['your_picture', 'name1', 'name2', 'Chat mode']], shared.gradio['display'])
|
shared.gradio['your_picture'].change(chat.upload_your_profile_picture, [shared.gradio[k] for k in ['your_picture', 'name1', 'name2', 'mode']], shared.gradio['display'])
|
||||||
|
|
||||||
reload_inputs = [shared.gradio[k] for k in ['name1', 'name2', 'Chat mode']]
|
|
||||||
shared.gradio['upload_chat_history'].upload(chat.redraw_html, reload_inputs, [shared.gradio['display']])
|
|
||||||
shared.gradio['Stop'].click(chat.redraw_html, reload_inputs, [shared.gradio['display']])
|
|
||||||
shared.gradio['Instruction templates'].change(chat.redraw_html, reload_inputs, [shared.gradio['display']])
|
|
||||||
shared.gradio['Chat mode'].change(chat.redraw_html, reload_inputs, [shared.gradio['display']])
|
|
||||||
|
|
||||||
shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js+ui.chat_js}}}")
|
shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js+ui.chat_js}}}")
|
||||||
shared.gradio['interface'].load(lambda : chat.load_default_history(shared.settings['name1'], shared.settings['name2']), None, None)
|
shared.gradio['interface'].load(chat.load_default_history, [shared.gradio[k] for k in ['name1', 'name2']], None)
|
||||||
shared.gradio['interface'].load(chat.redraw_html, reload_inputs, [shared.gradio['display']], show_progress=True)
|
shared.gradio['interface'].load(chat.redraw_html, reload_inputs, shared.gradio['display'], show_progress=True)
|
||||||
|
|
||||||
elif shared.args.notebook:
|
elif shared.args.notebook:
|
||||||
|
shared.input_elements = list_interface_input_elements(chat=False)
|
||||||
|
shared.gradio['interface_state'] = gr.State({k: None for k in shared.input_elements})
|
||||||
with gr.Tab("Text generation", elem_id="main"):
|
with gr.Tab("Text generation", elem_id="main"):
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
with gr.Column(scale=4):
|
with gr.Column(scale=4):
|
||||||
@@ -458,14 +580,27 @@ def create_interface():
|
|||||||
with gr.Tab("Parameters", elem_id="parameters"):
|
with gr.Tab("Parameters", elem_id="parameters"):
|
||||||
create_settings_menus(default_preset)
|
create_settings_menus(default_preset)
|
||||||
|
|
||||||
shared.input_params = [shared.gradio[k] for k in ['textbox', 'generate_state']]
|
shared.input_params = [shared.gradio[k] for k in ['textbox', 'interface_state']]
|
||||||
output_params = [shared.gradio[k] for k in ['textbox', 'markdown', 'html']]
|
output_params = [shared.gradio[k] for k in ['textbox', 'markdown', 'html']]
|
||||||
gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream))
|
|
||||||
gen_events.append(shared.gradio['textbox'].submit(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream))
|
gen_events.append(shared.gradio['Generate'].click(
|
||||||
shared.gradio['Stop'].click(stop_everything_event, [], [], queue=False, cancels=gen_events if shared.args.no_stream else None)
|
gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
|
||||||
|
generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream)#.then(
|
||||||
|
#None, None, None, _js="() => {element = document.getElementsByTagName('textarea')[0]; element.scrollTop = element.scrollHeight}")
|
||||||
|
)
|
||||||
|
|
||||||
|
gen_events.append(shared.gradio['textbox'].submit(
|
||||||
|
gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
|
||||||
|
generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream)#.then(
|
||||||
|
#None, None, None, _js="() => {element = document.getElementsByTagName('textarea')[0]; element.scrollTop = element.scrollHeight}")
|
||||||
|
)
|
||||||
|
|
||||||
|
shared.gradio['Stop'].click(stop_everything_event, None, None, queue=False, cancels=gen_events if shared.args.no_stream else None)
|
||||||
shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js}}}")
|
shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js}}}")
|
||||||
|
|
||||||
else:
|
else:
|
||||||
|
shared.input_elements = list_interface_input_elements(chat=False)
|
||||||
|
shared.gradio['interface_state'] = gr.State({k: None for k in shared.input_elements})
|
||||||
with gr.Tab("Text generation", elem_id="main"):
|
with gr.Tab("Text generation", elem_id="main"):
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
@@ -491,14 +626,33 @@ def create_interface():
|
|||||||
with gr.Tab("Parameters", elem_id="parameters"):
|
with gr.Tab("Parameters", elem_id="parameters"):
|
||||||
create_settings_menus(default_preset)
|
create_settings_menus(default_preset)
|
||||||
|
|
||||||
shared.input_params = [shared.gradio[k] for k in ['textbox', 'generate_state']]
|
shared.input_params = [shared.gradio[k] for k in ['textbox', 'interface_state']]
|
||||||
output_params = [shared.gradio[k] for k in ['output_textbox', 'markdown', 'html']]
|
output_params = [shared.gradio[k] for k in ['output_textbox', 'markdown', 'html']]
|
||||||
gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream))
|
|
||||||
gen_events.append(shared.gradio['textbox'].submit(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream))
|
gen_events.append(shared.gradio['Generate'].click(
|
||||||
gen_events.append(shared.gradio['Continue'].click(generate_reply, [shared.gradio['output_textbox']] + shared.input_params[1:], output_params, show_progress=shared.args.no_stream))
|
gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
|
||||||
shared.gradio['Stop'].click(stop_everything_event, [], [], queue=False, cancels=gen_events if shared.args.no_stream else None)
|
generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream)#.then(
|
||||||
|
#None, None, None, _js="() => {element = document.getElementsByTagName('textarea')[1]; element.scrollTop = element.scrollHeight}")
|
||||||
|
)
|
||||||
|
|
||||||
|
gen_events.append(shared.gradio['textbox'].submit(
|
||||||
|
gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
|
||||||
|
generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream)#.then(
|
||||||
|
#None, None, None, _js="() => {element = document.getElementsByTagName('textarea')[1]; element.scrollTop = element.scrollHeight}")
|
||||||
|
)
|
||||||
|
|
||||||
|
gen_events.append(shared.gradio['Continue'].click(
|
||||||
|
gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
|
||||||
|
generate_reply, [shared.gradio['output_textbox']] + shared.input_params[1:], output_params, show_progress=shared.args.no_stream)#.then(
|
||||||
|
#None, None, None, _js="() => {element = document.getElementsByTagName('textarea')[1]; element.scrollTop = element.scrollHeight}")
|
||||||
|
)
|
||||||
|
|
||||||
|
shared.gradio['Stop'].click(stop_everything_event, None, None, queue=False, cancels=gen_events if shared.args.no_stream else None)
|
||||||
shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js}}}")
|
shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js}}}")
|
||||||
|
|
||||||
|
with gr.Tab("Model", elem_id="model-tab"):
|
||||||
|
create_model_menus()
|
||||||
|
|
||||||
with gr.Tab("Training", elem_id="training-tab"):
|
with gr.Tab("Training", elem_id="training-tab"):
|
||||||
training.create_train_interface()
|
training.create_train_interface()
|
||||||
|
|
||||||
@@ -512,32 +666,21 @@ def create_interface():
|
|||||||
cmd_list = vars(shared.args)
|
cmd_list = vars(shared.args)
|
||||||
bool_list = [k for k in cmd_list if type(cmd_list[k]) is bool and k not in modes]
|
bool_list = [k for k in cmd_list if type(cmd_list[k]) is bool and k not in modes]
|
||||||
bool_active = [k for k in bool_list if vars(shared.args)[k]]
|
bool_active = [k for k in bool_list if vars(shared.args)[k]]
|
||||||
#int_list = [k for k in cmd_list if type(k) is int]
|
|
||||||
|
|
||||||
gr.Markdown("*Experimental*")
|
gr.Markdown("*Experimental*")
|
||||||
shared.gradio['interface_modes_menu'] = gr.Dropdown(choices=modes, value=current_mode, label="Mode")
|
shared.gradio['interface_modes_menu'] = gr.Dropdown(choices=modes, value=current_mode, label="Mode")
|
||||||
shared.gradio['extensions_menu'] = gr.CheckboxGroup(choices=get_available_extensions(), value=shared.args.extensions, label="Available extensions")
|
shared.gradio['extensions_menu'] = gr.CheckboxGroup(choices=get_available_extensions(), value=shared.args.extensions, label="Available extensions")
|
||||||
shared.gradio['bool_menu'] = gr.CheckboxGroup(choices=bool_list, value=bool_active, label="Boolean command-line flags")
|
shared.gradio['bool_menu'] = gr.CheckboxGroup(choices=bool_list, value=bool_active, label="Boolean command-line flags")
|
||||||
shared.gradio['reset_interface'] = gr.Button("Apply and restart the interface", type="primary")
|
shared.gradio['reset_interface'] = gr.Button("Apply and restart the interface")
|
||||||
|
|
||||||
shared.gradio['reset_interface'].click(set_interface_arguments, [shared.gradio[k] for k in ['interface_modes_menu', 'extensions_menu', 'bool_menu']], None)
|
# Reset interface event
|
||||||
shared.gradio['reset_interface'].click(lambda : None, None, None, _js='() => {document.body.innerHTML=\'<h1 style="font-family:monospace;margin-top:20%;color:lightgray;text-align:center;">Reloading...</h1>\'; setTimeout(function(){location.reload()},2500); return []}')
|
shared.gradio['reset_interface'].click(
|
||||||
|
set_interface_arguments, [shared.gradio[k] for k in ['interface_modes_menu', 'extensions_menu', 'bool_menu']], None).then(
|
||||||
|
lambda: None, None, None, _js='() => {document.body.innerHTML=\'<h1 style="font-family:monospace;margin-top:20%;color:lightgray;text-align:center;">Reloading...</h1>\'; setTimeout(function(){location.reload()},2500); return []}')
|
||||||
|
|
||||||
if shared.args.extensions is not None:
|
if shared.args.extensions is not None:
|
||||||
extensions_module.create_extensions_block()
|
extensions_module.create_extensions_block()
|
||||||
|
|
||||||
def change_dict_value(d, key, value):
|
|
||||||
d[key] = value
|
|
||||||
return d
|
|
||||||
|
|
||||||
for k in ['do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'max_new_tokens', 'seed', 'stop_at_newline', 'chat_prompt_size_slider', 'chat_generation_attempts']:
|
|
||||||
if k not in shared.gradio:
|
|
||||||
continue
|
|
||||||
if type(shared.gradio[k]) in [gr.Checkbox, gr.Number]:
|
|
||||||
shared.gradio[k].change(lambda state, value, copy=k: change_dict_value(state, copy, value), inputs=[shared.gradio['generate_state'], shared.gradio[k]], outputs=shared.gradio['generate_state'])
|
|
||||||
else:
|
|
||||||
shared.gradio[k].release(lambda state, value, copy=k: change_dict_value(state, copy, value), inputs=[shared.gradio['generate_state'], shared.gradio[k]], outputs=shared.gradio['generate_state'])
|
|
||||||
|
|
||||||
if not shared.is_chat():
|
if not shared.is_chat():
|
||||||
api.create_apis()
|
api.create_apis()
|
||||||
|
|
||||||
@@ -557,6 +700,7 @@ def create_interface():
|
|||||||
else:
|
else:
|
||||||
shared.gradio['interface'].launch(prevent_thread_lock=True, share=shared.args.share, server_port=shared.args.listen_port, inbrowser=shared.args.auto_launch, auth=auth)
|
shared.gradio['interface'].launch(prevent_thread_lock=True, share=shared.args.share, server_port=shared.args.listen_port, inbrowser=shared.args.auto_launch, auth=auth)
|
||||||
|
|
||||||
|
|
||||||
create_interface()
|
create_interface()
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
|
|||||||
@@ -7,7 +7,14 @@
|
|||||||
"name2": "Assistant",
|
"name2": "Assistant",
|
||||||
"context": "This is a conversation with your Assistant. The Assistant is very helpful and is eager to chat with you and answer your questions.",
|
"context": "This is a conversation with your Assistant. The Assistant is very helpful and is eager to chat with you and answer your questions.",
|
||||||
"greeting": "Hello there!",
|
"greeting": "Hello there!",
|
||||||
|
"end_of_turn": "",
|
||||||
|
"custom_stopping_strings": "",
|
||||||
"stop_at_newline": false,
|
"stop_at_newline": false,
|
||||||
|
"add_bos_token": true,
|
||||||
|
"ban_eos_token": true,
|
||||||
|
"truncation_length": 2048,
|
||||||
|
"truncation_length_min": 0,
|
||||||
|
"truncation_length_max": 4096,
|
||||||
"chat_prompt_size": 2048,
|
"chat_prompt_size": 2048,
|
||||||
"chat_prompt_size_min": 0,
|
"chat_prompt_size_min": 0,
|
||||||
"chat_prompt_size_max": 2048,
|
"chat_prompt_size_max": 2048,
|
||||||
@@ -19,7 +26,8 @@
|
|||||||
"gallery"
|
"gallery"
|
||||||
],
|
],
|
||||||
"presets": {
|
"presets": {
|
||||||
"default": "NovelAI-Sphinx Moth",
|
"default": "Default",
|
||||||
|
".*(alpaca|llama)": "LLaMA-Precise",
|
||||||
".*pygmalion": "NovelAI-Storywriter",
|
".*pygmalion": "NovelAI-Storywriter",
|
||||||
".*RWKV": "Naive"
|
".*RWKV": "Naive"
|
||||||
},
|
},
|
||||||
|
|||||||
Reference in New Issue
Block a user