Compare commits
92 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
49ce866c99 | ||
|
|
ff610b47d2 | ||
|
|
3850f13624 | ||
|
|
461ca7faf5 | ||
|
|
832ee4323d | ||
|
|
1405cd8af2 | ||
|
|
2289d3686f | ||
|
|
61641a4551 | ||
|
|
f2be87235d | ||
|
|
8265d45db8 | ||
|
|
37d52c96bc | ||
|
|
f2ec880e81 | ||
|
|
f34f2daa3d | ||
|
|
cacbcda208 | ||
|
|
749c08a4ff | ||
|
|
e9e93189ff | ||
|
|
dc3c9d00a0 | ||
|
|
457d3c58eb | ||
|
|
78bbc66fc4 | ||
|
|
0f212093a3 | ||
|
|
64f5c90ee7 | ||
|
|
58b34c0841 | ||
|
|
5234071c04 | ||
|
|
09d8119e3c | ||
|
|
0caf718a21 | ||
|
|
85a7954823 | ||
|
|
d37b4f76b1 | ||
|
|
bd04ff27ad | ||
|
|
f035b01823 | ||
|
|
b7ca89ba3f | ||
|
|
52339e9b20 | ||
|
|
4961f43702 | ||
|
|
617530296e | ||
|
|
0f1627eff1 | ||
|
|
d679c4be13 | ||
|
|
45244ed125 | ||
|
|
7e70741a4e | ||
|
|
11b23db8d4 | ||
|
|
2c14df81a8 | ||
|
|
c6e9ba20a4 | ||
|
|
843f672227 | ||
|
|
769aa900ea | ||
|
|
32d078487e | ||
|
|
30befe492a | ||
|
|
1911504f82 | ||
|
|
8178fde2cb | ||
|
|
dba2000d2b | ||
|
|
65552d2157 | ||
|
|
8c6155251a | ||
|
|
992663fa20 | ||
|
|
625d81f495 | ||
|
|
57f768eaad | ||
|
|
a3085dba07 | ||
|
|
120f5662cf | ||
|
|
b27d757fd1 | ||
|
|
d29f4624e9 | ||
|
|
170e0c05c4 | ||
|
|
34ec02d41d | ||
|
|
f91d3a3ff4 | ||
|
|
ebdf4c8c12 | ||
|
|
7436dd5b4a | ||
|
|
bce1b7fbb2 | ||
|
|
f7860ce192 | ||
|
|
ece8ed2c84 | ||
|
|
cc693a7546 | ||
|
|
2fde50a800 | ||
|
|
acc235aced | ||
|
|
df561fd896 | ||
|
|
d272ac46dd | ||
|
|
cb169d0834 | ||
|
|
2f16d0afca | ||
|
|
a6a00cb82f | ||
|
|
c97c270040 | ||
|
|
0b458bf82d | ||
|
|
ffd102e5c0 | ||
|
|
5543a5089d | ||
|
|
1dc464dcb0 | ||
|
|
962e33dc10 | ||
|
|
42ea6a3fc0 | ||
|
|
e563b015d8 | ||
|
|
1c413ed593 | ||
|
|
3f922d4bfb | ||
|
|
744bf7cbf2 | ||
|
|
768354239b | ||
|
|
6762e62a40 | ||
|
|
a453d4e9c4 | ||
|
|
ec979cd9c4 | ||
|
|
2c0018d946 | ||
|
|
8fa182cfa7 | ||
|
|
862aad637b | ||
|
|
46c4654226 | ||
|
|
ea6e77df72 |
@@ -1,7 +1,6 @@
|
|||||||
.env
|
.env
|
||||||
Dockerfile
|
Dockerfile
|
||||||
/characters
|
/characters
|
||||||
/extensions
|
|
||||||
/loras
|
/loras
|
||||||
/models
|
/models
|
||||||
/presets
|
/presets
|
||||||
|
|||||||
21
Dockerfile
21
Dockerfile
@@ -26,12 +26,11 @@ LABEL maintainer="Your Name <your.email@example.com>"
|
|||||||
LABEL description="Docker image for GPTQ-for-LLaMa and Text Generation WebUI"
|
LABEL description="Docker image for GPTQ-for-LLaMa and Text Generation WebUI"
|
||||||
|
|
||||||
RUN apt-get update && \
|
RUN apt-get update && \
|
||||||
apt-get install --no-install-recommends -y git python3 python3-pip && \
|
apt-get install --no-install-recommends -y git python3 python3-pip make g++ && \
|
||||||
rm -rf /var/lib/apt/lists/*
|
rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
RUN --mount=type=cache,target=/root/.cache/pip pip3 install virtualenv
|
RUN --mount=type=cache,target=/root/.cache/pip pip3 install virtualenv
|
||||||
|
RUN mkdir /app
|
||||||
COPY . /app/
|
|
||||||
|
|
||||||
WORKDIR /app
|
WORKDIR /app
|
||||||
|
|
||||||
@@ -41,21 +40,29 @@ RUN test -n "${WEBUI_VERSION}" && git reset --hard ${WEBUI_VERSION} || echo "Usi
|
|||||||
RUN virtualenv /app/venv
|
RUN virtualenv /app/venv
|
||||||
RUN . /app/venv/bin/activate && \
|
RUN . /app/venv/bin/activate && \
|
||||||
pip3 install --upgrade pip setuptools && \
|
pip3 install --upgrade pip setuptools && \
|
||||||
pip3 install torch torchvision torchaudio && \
|
pip3 install torch torchvision torchaudio
|
||||||
pip3 install -r requirements.txt
|
|
||||||
|
|
||||||
COPY --from=builder /build /app/repositories/GPTQ-for-LLaMa
|
COPY --from=builder /build /app/repositories/GPTQ-for-LLaMa
|
||||||
RUN . /app/venv/bin/activate && \
|
RUN . /app/venv/bin/activate && \
|
||||||
pip3 install /app/repositories/GPTQ-for-LLaMa/*.whl
|
pip3 install /app/repositories/GPTQ-for-LLaMa/*.whl
|
||||||
|
|
||||||
ENV CLI_ARGS=""
|
COPY extensions/api/requirements.txt /app/extensions/api/requirements.txt
|
||||||
|
COPY extensions/elevenlabs_tts/requirements.txt /app/extensions/elevenlabs_tts/requirements.txt
|
||||||
|
COPY extensions/google_translate/requirements.txt /app/extensions/google_translate/requirements.txt
|
||||||
|
COPY extensions/silero_tts/requirements.txt /app/extensions/silero_tts/requirements.txt
|
||||||
|
COPY extensions/whisper_stt/requirements.txt /app/extensions/whisper_stt/requirements.txt
|
||||||
RUN --mount=type=cache,target=/root/.cache/pip . /app/venv/bin/activate && cd extensions/api && pip3 install -r requirements.txt
|
RUN --mount=type=cache,target=/root/.cache/pip . /app/venv/bin/activate && cd extensions/api && pip3 install -r requirements.txt
|
||||||
RUN --mount=type=cache,target=/root/.cache/pip . /app/venv/bin/activate && cd extensions/elevenlabs_tts && pip3 install -r requirements.txt
|
RUN --mount=type=cache,target=/root/.cache/pip . /app/venv/bin/activate && cd extensions/elevenlabs_tts && pip3 install -r requirements.txt
|
||||||
RUN --mount=type=cache,target=/root/.cache/pip . /app/venv/bin/activate && cd extensions/google_translate && pip3 install -r requirements.txt
|
RUN --mount=type=cache,target=/root/.cache/pip . /app/venv/bin/activate && cd extensions/google_translate && pip3 install -r requirements.txt
|
||||||
RUN --mount=type=cache,target=/root/.cache/pip . /app/venv/bin/activate && cd extensions/silero_tts && pip3 install -r requirements.txt
|
RUN --mount=type=cache,target=/root/.cache/pip . /app/venv/bin/activate && cd extensions/silero_tts && pip3 install -r requirements.txt
|
||||||
RUN --mount=type=cache,target=/root/.cache/pip . /app/venv/bin/activate && cd extensions/whisper_stt && pip3 install -r requirements.txt
|
RUN --mount=type=cache,target=/root/.cache/pip . /app/venv/bin/activate && cd extensions/whisper_stt && pip3 install -r requirements.txt
|
||||||
|
|
||||||
|
COPY requirements.txt /app/requirements.txt
|
||||||
|
RUN . /app/venv/bin/activate && \
|
||||||
|
pip3 install -r requirements.txt
|
||||||
|
|
||||||
RUN cp /app/venv/lib/python3.10/site-packages/bitsandbytes/libbitsandbytes_cuda118.so /app/venv/lib/python3.10/site-packages/bitsandbytes/libbitsandbytes_cpu.so
|
RUN cp /app/venv/lib/python3.10/site-packages/bitsandbytes/libbitsandbytes_cuda118.so /app/venv/lib/python3.10/site-packages/bitsandbytes/libbitsandbytes_cpu.so
|
||||||
|
|
||||||
|
COPY . /app/
|
||||||
|
ENV CLI_ARGS=""
|
||||||
CMD . /app/venv/bin/activate && python3 server.py ${CLI_ARGS}
|
CMD . /app/venv/bin/activate && python3 server.py ${CLI_ARGS}
|
||||||
|
|||||||
49
README.md
49
README.md
@@ -1,11 +1,9 @@
|
|||||||
# Text generation web UI
|
# Text generation web UI
|
||||||
|
|
||||||
A gradio web UI for running Large Language Models like LLaMA, llama.cpp, GPT-J, OPT, and GALACTICA.
|
A gradio web UI for running Large Language Models like LLaMA, llama.cpp, GPT-J, Pythia, OPT, and GALACTICA.
|
||||||
|
|
||||||
Its goal is to become the [AUTOMATIC1111/stable-diffusion-webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui) of text generation.
|
Its goal is to become the [AUTOMATIC1111/stable-diffusion-webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui) of text generation.
|
||||||
|
|
||||||
[[Try it on Google Colab]](https://colab.research.google.com/github/oobabooga/AI-Notebooks/blob/main/Colab-TextGen-GPU.ipynb)
|
|
||||||
|
|
||||||
| |  |
|
| |  |
|
||||||
|:---:|:---:|
|
|:---:|:---:|
|
||||||
| |  |
|
| |  |
|
||||||
@@ -15,7 +13,7 @@ Its goal is to become the [AUTOMATIC1111/stable-diffusion-webui](https://github.
|
|||||||
* Dropdown menu for switching between models
|
* Dropdown menu for switching between models
|
||||||
* Notebook mode that resembles OpenAI's playground
|
* Notebook mode that resembles OpenAI's playground
|
||||||
* Chat mode for conversation and role playing
|
* Chat mode for conversation and role playing
|
||||||
* Instruct mode compatible with Alpaca and Open Assistant formats **\*NEW!\***
|
* Instruct mode compatible with Alpaca, Vicuna, and Open Assistant formats **\*NEW!\***
|
||||||
* Nice HTML output for GPT-4chan
|
* Nice HTML output for GPT-4chan
|
||||||
* Markdown output for [GALACTICA](https://github.com/paperswithcode/galai), including LaTeX rendering
|
* Markdown output for [GALACTICA](https://github.com/paperswithcode/galai), including LaTeX rendering
|
||||||
* [Custom chat characters](https://github.com/oobabooga/text-generation-webui/wiki/Custom-chat-characters)
|
* [Custom chat characters](https://github.com/oobabooga/text-generation-webui/wiki/Custom-chat-characters)
|
||||||
@@ -34,7 +32,6 @@ Its goal is to become the [AUTOMATIC1111/stable-diffusion-webui](https://github.
|
|||||||
* [LoRA (loading and training)](https://github.com/oobabooga/text-generation-webui/wiki/Using-LoRAs)
|
* [LoRA (loading and training)](https://github.com/oobabooga/text-generation-webui/wiki/Using-LoRAs)
|
||||||
* Softprompts
|
* Softprompts
|
||||||
* [Extensions](https://github.com/oobabooga/text-generation-webui/wiki/Extensions)
|
* [Extensions](https://github.com/oobabooga/text-generation-webui/wiki/Extensions)
|
||||||
* [Google Colab](https://github.com/oobabooga/text-generation-webui/wiki/Running-on-Colab)
|
|
||||||
|
|
||||||
## Installation
|
## Installation
|
||||||
|
|
||||||
@@ -73,9 +70,15 @@ On Linux or WSL, it can be automatically installed with these two commands:
|
|||||||
curl -sL "https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh" > "Miniconda3.sh"
|
curl -sL "https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh" > "Miniconda3.sh"
|
||||||
bash Miniconda3.sh
|
bash Miniconda3.sh
|
||||||
```
|
```
|
||||||
|
|
||||||
Source: https://educe-ubc.github.io/conda.html
|
Source: https://educe-ubc.github.io/conda.html
|
||||||
|
|
||||||
|
#### 0.1 (Ubuntu/WSL) Install build tools
|
||||||
|
|
||||||
|
```
|
||||||
|
sudo apt install build-essential
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
#### 1. Create a new conda environment
|
#### 1. Create a new conda environment
|
||||||
|
|
||||||
```
|
```
|
||||||
@@ -119,7 +122,7 @@ As an alternative to the recommended WSL method, you can install the web UI nati
|
|||||||
|
|
||||||
```
|
```
|
||||||
cp .env.example .env
|
cp .env.example .env
|
||||||
docker-compose up --build
|
docker compose up --build
|
||||||
```
|
```
|
||||||
|
|
||||||
Make sure to edit `.env.example` and set the appropriate CUDA version for your GPU.
|
Make sure to edit `.env.example` and set the appropriate CUDA version for your GPU.
|
||||||
@@ -192,25 +195,25 @@ Optionally, you can use the following command-line flags:
|
|||||||
#### Basic settings
|
#### Basic settings
|
||||||
|
|
||||||
| Flag | Description |
|
| Flag | Description |
|
||||||
|------------------|-------------|
|
|--------------------------------------------|-------------|
|
||||||
| `-h`, `--help` | show this help message and exit |
|
| `-h`, `--help` | Show this help message and exit. |
|
||||||
| `--notebook` | Launch the web UI in notebook mode, where the output is written to the same text box as the input. |
|
| `--notebook` | Launch the web UI in notebook mode, where the output is written to the same text box as the input. |
|
||||||
| `--chat` | Launch the web UI in chat mode.|
|
| `--chat` | Launch the web UI in chat mode. |
|
||||||
| `--model MODEL` | Name of the model to load by default. |
|
| `--model MODEL` | Name of the model to load by default. |
|
||||||
| `--lora LORA` | Name of the LoRA to apply to the model by default. |
|
| `--lora LORA` | Name of the LoRA to apply to the model by default. |
|
||||||
| `--model-dir MODEL_DIR` | Path to directory with all the models |
|
| `--model-dir MODEL_DIR` | Path to directory with all the models. |
|
||||||
| `--lora-dir LORA_DIR` | Path to directory with all the loras |
|
| `--lora-dir LORA_DIR` | Path to directory with all the loras. |
|
||||||
| `--no-stream` | Don't stream the text output in real time. |
|
| `--no-stream` | Don't stream the text output in real time. |
|
||||||
| `--settings SETTINGS_FILE` | Load the default interface settings from this json file. See `settings-template.json` for an example. If you create a file called `settings.json`, this file will be loaded by default without the need to use the `--settings` flag.|
|
| `--settings SETTINGS_FILE` | Load the default interface settings from this json file. See `settings-template.json` for an example. If you create a file called `settings.json`, this file will be loaded by default without the need to use the `--settings` flag. |
|
||||||
| `--extensions EXTENSIONS [EXTENSIONS ...]` | The list of extensions to load. If you want to load more than one extension, write the names separated by spaces. |
|
| `--extensions EXTENSIONS [EXTENSIONS ...]` | The list of extensions to load. If you want to load more than one extension, write the names separated by spaces. |
|
||||||
| `--verbose` | Print the prompts to the terminal. |
|
| `--verbose` | Print the prompts to the terminal. |
|
||||||
|
|
||||||
#### Accelerate/transformers
|
#### Accelerate/transformers
|
||||||
|
|
||||||
| Flag | Description |
|
| Flag | Description |
|
||||||
|------------------|-------------|
|
|---------------------------------------------|-------------|
|
||||||
| `--cpu` | Use the CPU to generate text.|
|
| `--cpu` | Use the CPU to generate text. Warning: Training on CPU is extremely slow.|
|
||||||
| `--auto-devices` | Automatically split the model across the available GPU(s) and CPU.|
|
| `--auto-devices` | Automatically split the model across the available GPU(s) and CPU. |
|
||||||
| `--gpu-memory GPU_MEMORY [GPU_MEMORY ...]` | Maxmimum GPU memory in GiB to be allocated per GPU. Example: `--gpu-memory 10` for a single GPU, `--gpu-memory 10 5` for two GPUs. You can also set values in MiB like `--gpu-memory 3500MiB`. |
|
| `--gpu-memory GPU_MEMORY [GPU_MEMORY ...]` | Maxmimum GPU memory in GiB to be allocated per GPU. Example: `--gpu-memory 10` for a single GPU, `--gpu-memory 10 5` for two GPUs. You can also set values in MiB like `--gpu-memory 3500MiB`. |
|
||||||
| `--cpu-memory CPU_MEMORY` | Maximum CPU memory in GiB to allocate for offloaded weights. Same as above.|
|
| `--cpu-memory CPU_MEMORY` | Maximum CPU memory in GiB to allocate for offloaded weights. Same as above.|
|
||||||
| `--disk` | If the model is too large for your GPU(s) and CPU combined, send the remaining layers to the disk. |
|
| `--disk` | If the model is too large for your GPU(s) and CPU combined, send the remaining layers to the disk. |
|
||||||
@@ -218,17 +221,19 @@ Optionally, you can use the following command-line flags:
|
|||||||
| `--load-in-8bit` | Load the model with 8-bit precision.|
|
| `--load-in-8bit` | Load the model with 8-bit precision.|
|
||||||
| `--bf16` | Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU. |
|
| `--bf16` | Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU. |
|
||||||
| `--no-cache` | Set `use_cache` to False while generating text. This reduces the VRAM usage a bit with a performance cost. |
|
| `--no-cache` | Set `use_cache` to False while generating text. This reduces the VRAM usage a bit with a performance cost. |
|
||||||
|
| `--xformers` | Use xformer's memory efficient attention. This should increase your tokens/s. |
|
||||||
|
| `--sdp-attention` | Use torch 2.0's sdp attention. |
|
||||||
|
|
||||||
#### llama.cpp
|
#### llama.cpp
|
||||||
|
|
||||||
| Flag | Description |
|
| Flag | Description |
|
||||||
|------------------|-------------|
|
|-------------|-------------|
|
||||||
| `--threads` | Number of threads to use in llama.cpp. |
|
| `--threads` | Number of threads to use in llama.cpp. |
|
||||||
|
|
||||||
#### GPTQ
|
#### GPTQ
|
||||||
|
|
||||||
| Flag | Description |
|
| Flag | Description |
|
||||||
|------------------|-------------|
|
|---------------------------|-------------|
|
||||||
| `--wbits WBITS` | GPTQ: Load a pre-quantized model with specified precision in bits. 2, 3, 4 and 8 are supported. |
|
| `--wbits WBITS` | GPTQ: Load a pre-quantized model with specified precision in bits. 2, 3, 4 and 8 are supported. |
|
||||||
| `--model_type MODEL_TYPE` | GPTQ: Model type of pre-quantized model. Currently LLaMA, OPT, and GPT-J are supported. |
|
| `--model_type MODEL_TYPE` | GPTQ: Model type of pre-quantized model. Currently LLaMA, OPT, and GPT-J are supported. |
|
||||||
| `--groupsize GROUPSIZE` | GPTQ: Group size. |
|
| `--groupsize GROUPSIZE` | GPTQ: Group size. |
|
||||||
@@ -246,7 +251,7 @@ Optionally, you can use the following command-line flags:
|
|||||||
#### DeepSpeed
|
#### DeepSpeed
|
||||||
|
|
||||||
| Flag | Description |
|
| Flag | Description |
|
||||||
|------------------|-------------|
|
|---------------------------------------|-------------|
|
||||||
| `--deepspeed` | Enable the use of DeepSpeed ZeRO-3 for inference via the Transformers integration. |
|
| `--deepspeed` | Enable the use of DeepSpeed ZeRO-3 for inference via the Transformers integration. |
|
||||||
| `--nvme-offload-dir NVME_OFFLOAD_DIR` | DeepSpeed: Directory to use for ZeRO-3 NVME offloading. |
|
| `--nvme-offload-dir NVME_OFFLOAD_DIR` | DeepSpeed: Directory to use for ZeRO-3 NVME offloading. |
|
||||||
| `--local_rank LOCAL_RANK` | DeepSpeed: Optional argument for distributed setups. |
|
| `--local_rank LOCAL_RANK` | DeepSpeed: Optional argument for distributed setups. |
|
||||||
@@ -254,14 +259,14 @@ Optionally, you can use the following command-line flags:
|
|||||||
#### RWKV
|
#### RWKV
|
||||||
|
|
||||||
| Flag | Description |
|
| Flag | Description |
|
||||||
|------------------|-------------|
|
|---------------------------------|-------------|
|
||||||
| `--rwkv-strategy RWKV_STRATEGY` | RWKV: The strategy to use while loading the model. Examples: "cpu fp32", "cuda fp16", "cuda fp16i8". |
|
| `--rwkv-strategy RWKV_STRATEGY` | RWKV: The strategy to use while loading the model. Examples: "cpu fp32", "cuda fp16", "cuda fp16i8". |
|
||||||
| `--rwkv-cuda-on` | RWKV: Compile the CUDA kernel for better performance. |
|
| `--rwkv-cuda-on` | RWKV: Compile the CUDA kernel for better performance. |
|
||||||
|
|
||||||
#### Gradio
|
#### Gradio
|
||||||
|
|
||||||
| Flag | Description |
|
| Flag | Description |
|
||||||
|------------------|-------------|
|
|---------------------------------------|-------------|
|
||||||
| `--listen` | Make the web UI reachable from your local network. |
|
| `--listen` | Make the web UI reachable from your local network. |
|
||||||
| `--listen-port LISTEN_PORT` | The listening port that the server will use. |
|
| `--listen-port LISTEN_PORT` | The listening port that the server will use. |
|
||||||
| `--share` | Create a public URL. This is useful for running the web UI on Google Colab or similar. |
|
| `--share` | Create a public URL. This is useful for running the web UI on Google Colab or similar. |
|
||||||
@@ -286,6 +291,8 @@ Check the [wiki](https://github.com/oobabooga/text-generation-webui/wiki/System-
|
|||||||
|
|
||||||
Pull requests, suggestions, and issue reports are welcome.
|
Pull requests, suggestions, and issue reports are welcome.
|
||||||
|
|
||||||
|
You are also welcome to review open pull requests.
|
||||||
|
|
||||||
Before reporting a bug, make sure that you have:
|
Before reporting a bug, make sure that you have:
|
||||||
|
|
||||||
1. Created a conda environment and installed the dependencies exactly as in the *Installation* section above.
|
1. Created a conda environment and installed the dependencies exactly as in the *Installation* section above.
|
||||||
|
|||||||
@@ -12,6 +12,11 @@ import string
|
|||||||
|
|
||||||
import websockets
|
import websockets
|
||||||
|
|
||||||
|
# Note, Gradio may pick a different fn value as the definition of the Gradio app changes.
|
||||||
|
# You can always launch the web UI and inspect the websocket stream using your browser's dev tools
|
||||||
|
# to determine what value Gradio expects here.
|
||||||
|
GRADIO_FN = 29
|
||||||
|
|
||||||
|
|
||||||
def random_hash():
|
def random_hash():
|
||||||
letters = string.ascii_lowercase + string.digits
|
letters = string.ascii_lowercase + string.digits
|
||||||
@@ -36,6 +41,10 @@ async def run(context):
|
|||||||
'length_penalty': 1,
|
'length_penalty': 1,
|
||||||
'early_stopping': False,
|
'early_stopping': False,
|
||||||
'seed': -1,
|
'seed': -1,
|
||||||
|
'add_bos_token': True,
|
||||||
|
'truncation_length': 2048,
|
||||||
|
'custom_stopping_strings': [],
|
||||||
|
'ban_eos_token': False
|
||||||
}
|
}
|
||||||
payload = json.dumps([context, params])
|
payload = json.dumps([context, params])
|
||||||
session = random_hash()
|
session = random_hash()
|
||||||
@@ -47,14 +56,14 @@ async def run(context):
|
|||||||
case "send_hash":
|
case "send_hash":
|
||||||
await websocket.send(json.dumps({
|
await websocket.send(json.dumps({
|
||||||
"session_hash": session,
|
"session_hash": session,
|
||||||
"fn_index": 12
|
"fn_index": GRADIO_FN
|
||||||
}))
|
}))
|
||||||
case "estimation":
|
case "estimation":
|
||||||
pass
|
pass
|
||||||
case "send_data":
|
case "send_data":
|
||||||
await websocket.send(json.dumps({
|
await websocket.send(json.dumps({
|
||||||
"session_hash": session,
|
"session_hash": session,
|
||||||
"fn_index": 12,
|
"fn_index": GRADIO_FN,
|
||||||
"data": [
|
"data": [
|
||||||
payload
|
payload
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -22,10 +22,10 @@ server = "127.0.0.1"
|
|||||||
params = {
|
params = {
|
||||||
'max_new_tokens': 200,
|
'max_new_tokens': 200,
|
||||||
'do_sample': True,
|
'do_sample': True,
|
||||||
'temperature': 0.5,
|
'temperature': 0.72,
|
||||||
'top_p': 0.9,
|
'top_p': 0.73,
|
||||||
'typical_p': 1,
|
'typical_p': 1,
|
||||||
'repetition_penalty': 1.05,
|
'repetition_penalty': 1.1,
|
||||||
'encoder_repetition_penalty': 1.0,
|
'encoder_repetition_penalty': 1.0,
|
||||||
'top_k': 0,
|
'top_k': 0,
|
||||||
'min_length': 0,
|
'min_length': 0,
|
||||||
@@ -35,6 +35,10 @@ params = {
|
|||||||
'length_penalty': 1,
|
'length_penalty': 1,
|
||||||
'early_stopping': False,
|
'early_stopping': False,
|
||||||
'seed': -1,
|
'seed': -1,
|
||||||
|
'add_bos_token': True,
|
||||||
|
'custom_stopping_strings': [],
|
||||||
|
'truncation_length': 2048,
|
||||||
|
'ban_eos_token': False,
|
||||||
}
|
}
|
||||||
|
|
||||||
# Input prompt
|
# Input prompt
|
||||||
|
|||||||
@@ -36,3 +36,8 @@ div.svelte-362y77>*, div.svelte-362y77>.form>* {
|
|||||||
.wrap.svelte-6roggh.svelte-6roggh {
|
.wrap.svelte-6roggh.svelte-6roggh {
|
||||||
max-height: 92.5%;
|
max-height: 92.5%;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* This is for the microphone button in the whisper extension */
|
||||||
|
.sm.svelte-1ipelgc {
|
||||||
|
width: 100%;
|
||||||
|
}
|
||||||
|
|||||||
@@ -7,11 +7,13 @@
|
|||||||
padding-right: 20px;
|
padding-right: 20px;
|
||||||
display: flex;
|
display: flex;
|
||||||
flex-direction: column-reverse;
|
flex-direction: column-reverse;
|
||||||
|
word-break: break-word;
|
||||||
|
overflow-wrap: anywhere;
|
||||||
}
|
}
|
||||||
|
|
||||||
.message {
|
.message {
|
||||||
display: grid;
|
display: grid;
|
||||||
grid-template-columns: 60px 1fr;
|
grid-template-columns: 60px minmax(0, 1fr);
|
||||||
padding-bottom: 25px;
|
padding-bottom: 25px;
|
||||||
font-size: 15px;
|
font-size: 15px;
|
||||||
font-family: Helvetica, Arial, sans-serif;
|
font-family: Helvetica, Arial, sans-serif;
|
||||||
@@ -73,6 +75,13 @@
|
|||||||
display: inline !important;
|
display: inline !important;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
.message-body code {
|
||||||
|
overflow-x: auto;
|
||||||
|
}
|
||||||
|
.message-body :not(pre) > code {
|
||||||
|
white-space: normal !important;
|
||||||
|
}
|
||||||
|
|
||||||
.dark .message-body p em {
|
.dark .message-body p em {
|
||||||
color: rgb(138, 138, 138) !important;
|
color: rgb(138, 138, 138) !important;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -7,6 +7,8 @@
|
|||||||
padding-right: 20px;
|
padding-right: 20px;
|
||||||
display: flex;
|
display: flex;
|
||||||
flex-direction: column-reverse;
|
flex-direction: column-reverse;
|
||||||
|
word-break: break-word;
|
||||||
|
overflow-wrap: anywhere;
|
||||||
}
|
}
|
||||||
|
|
||||||
.message {
|
.message {
|
||||||
@@ -25,9 +27,7 @@
|
|||||||
.message-body {}
|
.message-body {}
|
||||||
|
|
||||||
.message-body p {
|
.message-body p {
|
||||||
margin-bottom: 0 !important;
|
|
||||||
font-size: 15px !important;
|
font-size: 15px !important;
|
||||||
line-height: 1.428571429 !important;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
.message-body li {
|
.message-body li {
|
||||||
@@ -39,6 +39,13 @@
|
|||||||
display: inline !important;
|
display: inline !important;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
.message-body code {
|
||||||
|
overflow-x: auto;
|
||||||
|
}
|
||||||
|
.message-body :not(pre) > code {
|
||||||
|
white-space: normal !important;
|
||||||
|
}
|
||||||
|
|
||||||
.dark .message-body p em {
|
.dark .message-body p em {
|
||||||
color: rgb(138, 138, 138) !important;
|
color: rgb(138, 138, 138) !important;
|
||||||
}
|
}
|
||||||
@@ -51,15 +58,16 @@
|
|||||||
padding: 15px;
|
padding: 15px;
|
||||||
border-radius: 20px;
|
border-radius: 20px;
|
||||||
background-color: #0000000f;
|
background-color: #0000000f;
|
||||||
margin-bottom: 17.5px;
|
margin-top: 9px !important;
|
||||||
|
margin-bottom: 18px !important;
|
||||||
}
|
}
|
||||||
|
|
||||||
.gradio-container .chat .user-message {
|
.gradio-container .chat .user-message {
|
||||||
padding: 15px;
|
padding: 15px;
|
||||||
border-radius: 20px;
|
border-radius: 20px;
|
||||||
margin-bottom: 17.5px !important;
|
margin-bottom: 9px !important;
|
||||||
}
|
}
|
||||||
|
|
||||||
.dark .chat .assistant-message {
|
.dark .chat .assistant-message {
|
||||||
background-color: #ffffff21;
|
background-color: #374151;
|
||||||
}
|
}
|
||||||
10
css/main.css
10
css/main.css
@@ -67,3 +67,13 @@ span.math.inline {
|
|||||||
div.svelte-15lo0d8 > *, div.svelte-15lo0d8 > .form > * {
|
div.svelte-15lo0d8 > *, div.svelte-15lo0d8 > .form > * {
|
||||||
flex-wrap: nowrap;
|
flex-wrap: nowrap;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
.header_bar {
|
||||||
|
background-color: #f7f7f7;
|
||||||
|
margin-bottom: 40px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.dark .header_bar {
|
||||||
|
border: none !important;
|
||||||
|
background-color: #8080802b;
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
document.getElementById("main").parentNode.childNodes[0].style = "border: none; background-color: #8080802b; margin-bottom: 40px";
|
document.getElementById("main").parentNode.childNodes[0].classList.add("header_bar");
|
||||||
document.getElementById("main").parentNode.style = "padding: 0; margin: 0";
|
document.getElementById("main").parentNode.style = "padding: 0; margin: 0";
|
||||||
document.getElementById("main").parentNode.parentNode.parentNode.style = "padding: 0";
|
document.getElementById("main").parentNode.parentNode.parentNode.style = "padding: 0";
|
||||||
|
|
||||||
|
|||||||
@@ -6,7 +6,6 @@ services:
|
|||||||
args:
|
args:
|
||||||
# specify which cuda version your card supports: https://developer.nvidia.com/cuda-gpus
|
# specify which cuda version your card supports: https://developer.nvidia.com/cuda-gpus
|
||||||
TORCH_CUDA_ARCH_LIST: ${TORCH_CUDA_ARCH_LIST}
|
TORCH_CUDA_ARCH_LIST: ${TORCH_CUDA_ARCH_LIST}
|
||||||
GPTQ_VERSION: ${GPTQ_VERSION}
|
|
||||||
WEBUI_VERSION: ${WEBUI_VERSION}
|
WEBUI_VERSION: ${WEBUI_VERSION}
|
||||||
env_file: .env
|
env_file: .env
|
||||||
ports:
|
ports:
|
||||||
|
|||||||
@@ -19,50 +19,6 @@ import requests
|
|||||||
import tqdm
|
import tqdm
|
||||||
from tqdm.contrib.concurrent import thread_map
|
from tqdm.contrib.concurrent import thread_map
|
||||||
|
|
||||||
parser = argparse.ArgumentParser()
|
|
||||||
parser.add_argument('MODEL', type=str, default=None, nargs='?')
|
|
||||||
parser.add_argument('--branch', type=str, default='main', help='Name of the Git branch to download from.')
|
|
||||||
parser.add_argument('--threads', type=int, default=1, help='Number of files to download simultaneously.')
|
|
||||||
parser.add_argument('--text-only', action='store_true', help='Only download text files (txt/json).')
|
|
||||||
parser.add_argument('--output', type=str, default=None, help='The folder where the model should be saved.')
|
|
||||||
parser.add_argument('--clean', action='store_true', help='Does not resume the previous download.')
|
|
||||||
parser.add_argument('--check', action='store_true', help='Validates the checksums of model files.')
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
|
|
||||||
def get_file(url, output_folder):
|
|
||||||
filename = Path(url.rsplit('/', 1)[1])
|
|
||||||
output_path = output_folder / filename
|
|
||||||
if output_path.exists() and not args.clean:
|
|
||||||
# Check if the file has already been downloaded completely
|
|
||||||
r = requests.get(url, stream=True)
|
|
||||||
total_size = int(r.headers.get('content-length', 0))
|
|
||||||
if output_path.stat().st_size >= total_size:
|
|
||||||
return
|
|
||||||
# Otherwise, resume the download from where it left off
|
|
||||||
headers = {'Range': f'bytes={output_path.stat().st_size}-'}
|
|
||||||
mode = 'ab'
|
|
||||||
else:
|
|
||||||
headers = {}
|
|
||||||
mode = 'wb'
|
|
||||||
|
|
||||||
r = requests.get(url, stream=True, headers=headers)
|
|
||||||
with open(output_path, mode) as f:
|
|
||||||
total_size = int(r.headers.get('content-length', 0))
|
|
||||||
block_size = 1024
|
|
||||||
with tqdm.tqdm(total=total_size, unit='iB', unit_scale=True, bar_format='{l_bar}{bar}| {n_fmt:6}/{total_fmt:6} {rate_fmt:6}') as t:
|
|
||||||
for data in r.iter_content(block_size):
|
|
||||||
t.update(len(data))
|
|
||||||
f.write(data)
|
|
||||||
|
|
||||||
|
|
||||||
def sanitize_branch_name(branch_name):
|
|
||||||
pattern = re.compile(r"^[a-zA-Z0-9._-]+$")
|
|
||||||
if pattern.match(branch_name):
|
|
||||||
return branch_name
|
|
||||||
else:
|
|
||||||
raise ValueError("Invalid branch name. Only alphanumeric characters, period, underscore and dash are allowed.")
|
|
||||||
|
|
||||||
|
|
||||||
def select_model_from_default_options():
|
def select_model_from_default_options():
|
||||||
models = {
|
models = {
|
||||||
@@ -110,7 +66,20 @@ EleutherAI/pythia-1.4b-deduped
|
|||||||
return model, branch
|
return model, branch
|
||||||
|
|
||||||
|
|
||||||
def get_download_links_from_huggingface(model, branch):
|
def sanitize_model_and_branch_names(model, branch):
|
||||||
|
if model[-1] == '/':
|
||||||
|
model = model[:-1]
|
||||||
|
if branch is None:
|
||||||
|
branch = "main"
|
||||||
|
else:
|
||||||
|
pattern = re.compile(r"^[a-zA-Z0-9._-]+$")
|
||||||
|
if not pattern.match(branch):
|
||||||
|
raise ValueError("Invalid branch name. Only alphanumeric characters, period, underscore and dash are allowed.")
|
||||||
|
|
||||||
|
return model, branch
|
||||||
|
|
||||||
|
|
||||||
|
def get_download_links_from_huggingface(model, branch, text_only=False):
|
||||||
base = "https://huggingface.co"
|
base = "https://huggingface.co"
|
||||||
page = f"/api/models/{model}/tree/{branch}?cursor="
|
page = f"/api/models/{model}/tree/{branch}?cursor="
|
||||||
cursor = b""
|
cursor = b""
|
||||||
@@ -142,14 +111,14 @@ def get_download_links_from_huggingface(model, branch):
|
|||||||
is_tokenizer = re.match("tokenizer.*\.model", fname)
|
is_tokenizer = re.match("tokenizer.*\.model", fname)
|
||||||
is_text = re.match(".*\.(txt|json|py|md)", fname) or is_tokenizer
|
is_text = re.match(".*\.(txt|json|py|md)", fname) or is_tokenizer
|
||||||
|
|
||||||
if any((is_pytorch, is_safetensors, is_pt, is_tokenizer, is_text)):
|
if any((is_pytorch, is_safetensors, is_pt, is_ggml, is_tokenizer, is_text)):
|
||||||
if 'lfs' in dict[i]:
|
if 'lfs' in dict[i]:
|
||||||
sha256.append([fname, dict[i]['lfs']['oid']])
|
sha256.append([fname, dict[i]['lfs']['oid']])
|
||||||
if is_text:
|
if is_text:
|
||||||
links.append(f"https://huggingface.co/{model}/resolve/{branch}/{fname}")
|
links.append(f"https://huggingface.co/{model}/resolve/{branch}/{fname}")
|
||||||
classifications.append('text')
|
classifications.append('text')
|
||||||
continue
|
continue
|
||||||
if not args.text_only:
|
if not text_only:
|
||||||
links.append(f"https://huggingface.co/{model}/resolve/{branch}/{fname}")
|
links.append(f"https://huggingface.co/{model}/resolve/{branch}/{fname}")
|
||||||
if is_safetensors:
|
if is_safetensors:
|
||||||
has_safetensors = True
|
has_safetensors = True
|
||||||
@@ -177,41 +146,67 @@ def get_download_links_from_huggingface(model, branch):
|
|||||||
return links, sha256, is_lora
|
return links, sha256, is_lora
|
||||||
|
|
||||||
|
|
||||||
def download_files(file_list, output_folder, num_threads=8):
|
def get_output_folder(model, branch, is_lora, base_folder=None):
|
||||||
thread_map(lambda url: get_file(url, output_folder), file_list, max_workers=num_threads, disable=True)
|
if base_folder is None:
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
model = args.MODEL
|
|
||||||
branch = args.branch
|
|
||||||
if model is None:
|
|
||||||
model, branch = select_model_from_default_options()
|
|
||||||
else:
|
|
||||||
if model[-1] == '/':
|
|
||||||
model = model[:-1]
|
|
||||||
branch = args.branch
|
|
||||||
if branch is None:
|
|
||||||
branch = "main"
|
|
||||||
else:
|
|
||||||
try:
|
|
||||||
branch = sanitize_branch_name(branch)
|
|
||||||
except ValueError as err_branch:
|
|
||||||
print(f"Error: {err_branch}")
|
|
||||||
sys.exit()
|
|
||||||
|
|
||||||
links, sha256, is_lora = get_download_links_from_huggingface(model, branch)
|
|
||||||
|
|
||||||
if args.output is not None:
|
|
||||||
base_folder = args.output
|
|
||||||
else:
|
|
||||||
base_folder = 'models' if not is_lora else 'loras'
|
base_folder = 'models' if not is_lora else 'loras'
|
||||||
|
|
||||||
output_folder = f"{'_'.join(model.split('/')[-2:])}"
|
output_folder = f"{'_'.join(model.split('/')[-2:])}"
|
||||||
if branch != 'main':
|
if branch != 'main':
|
||||||
output_folder += f'_{branch}'
|
output_folder += f'_{branch}'
|
||||||
output_folder = Path(base_folder) / output_folder
|
output_folder = Path(base_folder) / output_folder
|
||||||
|
return output_folder
|
||||||
|
|
||||||
if args.check:
|
|
||||||
|
def get_single_file(url, output_folder, start_from_scratch=False):
|
||||||
|
filename = Path(url.rsplit('/', 1)[1])
|
||||||
|
output_path = output_folder / filename
|
||||||
|
if output_path.exists() and not start_from_scratch:
|
||||||
|
# Check if the file has already been downloaded completely
|
||||||
|
r = requests.get(url, stream=True)
|
||||||
|
total_size = int(r.headers.get('content-length', 0))
|
||||||
|
if output_path.stat().st_size >= total_size:
|
||||||
|
return
|
||||||
|
# Otherwise, resume the download from where it left off
|
||||||
|
headers = {'Range': f'bytes={output_path.stat().st_size}-'}
|
||||||
|
mode = 'ab'
|
||||||
|
else:
|
||||||
|
headers = {}
|
||||||
|
mode = 'wb'
|
||||||
|
|
||||||
|
r = requests.get(url, stream=True, headers=headers)
|
||||||
|
with open(output_path, mode) as f:
|
||||||
|
total_size = int(r.headers.get('content-length', 0))
|
||||||
|
block_size = 1024
|
||||||
|
with tqdm.tqdm(total=total_size, unit='iB', unit_scale=True, bar_format='{l_bar}{bar}| {n_fmt:6}/{total_fmt:6} {rate_fmt:6}') as t:
|
||||||
|
for data in r.iter_content(block_size):
|
||||||
|
t.update(len(data))
|
||||||
|
f.write(data)
|
||||||
|
|
||||||
|
|
||||||
|
def start_download_threads(file_list, output_folder, start_from_scratch=False, threads=1):
|
||||||
|
thread_map(lambda url: get_single_file(url, output_folder, start_from_scratch=start_from_scratch), file_list, max_workers=threads, disable=True)
|
||||||
|
|
||||||
|
|
||||||
|
def download_model_files(model, branch, links, sha256, output_folder, start_from_scratch=False, threads=1):
|
||||||
|
# Creating the folder and writing the metadata
|
||||||
|
if not output_folder.exists():
|
||||||
|
output_folder.mkdir()
|
||||||
|
with open(output_folder / 'huggingface-metadata.txt', 'w') as f:
|
||||||
|
f.write(f'url: https://huggingface.co/{model}\n')
|
||||||
|
f.write(f'branch: {branch}\n')
|
||||||
|
f.write(f'download date: {str(datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"))}\n')
|
||||||
|
sha256_str = ''
|
||||||
|
for i in range(len(sha256)):
|
||||||
|
sha256_str += f' {sha256[i][1]} {sha256[i][0]}\n'
|
||||||
|
if sha256_str != '':
|
||||||
|
f.write(f'sha256sum:\n{sha256_str}')
|
||||||
|
|
||||||
|
# Downloading the files
|
||||||
|
print(f"Downloading the model to {output_folder}")
|
||||||
|
start_download_threads(links, output_folder, start_from_scratch=start_from_scratch, threads=threads)
|
||||||
|
|
||||||
|
|
||||||
|
def check_model_files(model, branch, links, sha256, output_folder):
|
||||||
# Validate the checksums
|
# Validate the checksums
|
||||||
validated = True
|
validated = True
|
||||||
for i in range(len(sha256)):
|
for i in range(len(sha256)):
|
||||||
@@ -236,21 +231,40 @@ if __name__ == '__main__':
|
|||||||
else:
|
else:
|
||||||
print('[-] Invalid checksums. Rerun download-model.py with the --clean flag.')
|
print('[-] Invalid checksums. Rerun download-model.py with the --clean flag.')
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument('MODEL', type=str, default=None, nargs='?')
|
||||||
|
parser.add_argument('--branch', type=str, default='main', help='Name of the Git branch to download from.')
|
||||||
|
parser.add_argument('--threads', type=int, default=1, help='Number of files to download simultaneously.')
|
||||||
|
parser.add_argument('--text-only', action='store_true', help='Only download text files (txt/json).')
|
||||||
|
parser.add_argument('--output', type=str, default=None, help='The folder where the model should be saved.')
|
||||||
|
parser.add_argument('--clean', action='store_true', help='Does not resume the previous download.')
|
||||||
|
parser.add_argument('--check', action='store_true', help='Validates the checksums of model files.')
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
branch = args.branch
|
||||||
|
model = args.MODEL
|
||||||
|
if model is None:
|
||||||
|
model, branch = select_model_from_default_options()
|
||||||
|
|
||||||
|
# Cleaning up the model/branch names
|
||||||
|
try:
|
||||||
|
model, branch = sanitize_model_and_branch_names(model, branch)
|
||||||
|
except ValueError as err_branch:
|
||||||
|
print(f"Error: {err_branch}")
|
||||||
|
sys.exit()
|
||||||
|
|
||||||
|
# Getting the download links from Hugging Face
|
||||||
|
links, sha256, is_lora = get_download_links_from_huggingface(model, branch, text_only=args.text_only)
|
||||||
|
|
||||||
|
# Getting the output folder
|
||||||
|
output_folder = get_output_folder(model, branch, is_lora, base_folder=args.output)
|
||||||
|
|
||||||
|
if args.check:
|
||||||
|
# Check previously downloaded files
|
||||||
|
check_model_files(model, branch, links, sha256, output_folder)
|
||||||
else:
|
else:
|
||||||
|
# Download files
|
||||||
# Creating the folder and writing the metadata
|
download_model_files(model, branch, links, sha256, output_folder, threads=args.threads)
|
||||||
if not output_folder.exists():
|
|
||||||
output_folder.mkdir()
|
|
||||||
with open(output_folder / 'huggingface-metadata.txt', 'w') as f:
|
|
||||||
f.write(f'url: https://huggingface.co/{model}\n')
|
|
||||||
f.write(f'branch: {branch}\n')
|
|
||||||
f.write(f'download date: {str(datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"))}\n')
|
|
||||||
sha256_str = ''
|
|
||||||
for i in range(len(sha256)):
|
|
||||||
sha256_str += f' {sha256[i][1]} {sha256[i][0]}\n'
|
|
||||||
if sha256_str != '':
|
|
||||||
f.write(f'sha256sum:\n{sha256_str}')
|
|
||||||
|
|
||||||
# Downloading the files
|
|
||||||
print(f"Downloading the model to {output_folder}")
|
|
||||||
download_files(links, output_folder, args.threads)
|
|
||||||
|
|||||||
@@ -57,12 +57,15 @@ class Handler(BaseHTTPRequestHandler):
|
|||||||
'length_penalty': float(body.get('length_penalty', 1)),
|
'length_penalty': float(body.get('length_penalty', 1)),
|
||||||
'early_stopping': bool(body.get('early_stopping', False)),
|
'early_stopping': bool(body.get('early_stopping', False)),
|
||||||
'seed': int(body.get('seed', -1)),
|
'seed': int(body.get('seed', -1)),
|
||||||
|
'add_bos_token': int(body.get('add_bos_token', True)),
|
||||||
|
'custom_stopping_strings': body.get('custom_stopping_strings', []),
|
||||||
|
'truncation_length': int(body.get('truncation_length', 2048)),
|
||||||
|
'ban_eos_token': bool(body.get('ban_eos_token', False)),
|
||||||
}
|
}
|
||||||
|
|
||||||
generator = generate_reply(
|
generator = generate_reply(
|
||||||
prompt,
|
prompt,
|
||||||
generate_params,
|
generate_params,
|
||||||
stopping_strings=body.get('stopping_strings', []),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
answer = ''
|
answer = ''
|
||||||
@@ -78,6 +81,19 @@ class Handler(BaseHTTPRequestHandler):
|
|||||||
}]
|
}]
|
||||||
})
|
})
|
||||||
self.wfile.write(response.encode('utf-8'))
|
self.wfile.write(response.encode('utf-8'))
|
||||||
|
elif self.path == '/api/v1/token-count':
|
||||||
|
# Not compatible with KoboldAI api
|
||||||
|
self.send_response(200)
|
||||||
|
self.send_header('Content-Type', 'application/json')
|
||||||
|
self.end_headers()
|
||||||
|
|
||||||
|
tokens = encode(body['prompt'])[0]
|
||||||
|
response = json.dumps({
|
||||||
|
'results': [{
|
||||||
|
'tokens': len(tokens)
|
||||||
|
}]
|
||||||
|
})
|
||||||
|
self.wfile.write(response.encode('utf-8'))
|
||||||
else:
|
else:
|
||||||
self.send_error(404)
|
self.send_error(404)
|
||||||
|
|
||||||
|
|||||||
@@ -1,8 +1,23 @@
|
|||||||
import gradio as gr
|
import gradio as gr
|
||||||
|
import os
|
||||||
|
|
||||||
|
# get the current directory of the script
|
||||||
|
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
|
||||||
|
# check if the bias_options.txt file exists, if not, create it
|
||||||
|
bias_file = os.path.join(current_dir, "bias_options.txt")
|
||||||
|
if not os.path.isfile(bias_file):
|
||||||
|
with open(bias_file, "w") as f:
|
||||||
|
f.write("*I am so happy*\n*I am so sad*\n*I am so excited*\n*I am so bored*\n*I am so angry*")
|
||||||
|
|
||||||
|
# read bias options from the text file
|
||||||
|
with open(bias_file, "r") as f:
|
||||||
|
bias_options = [line.strip() for line in f.readlines()]
|
||||||
|
|
||||||
params = {
|
params = {
|
||||||
"activate": True,
|
"activate": True,
|
||||||
"bias string": " *I am so happy*",
|
"bias string": " *I am so happy*",
|
||||||
|
"use custom string": False,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -11,7 +26,6 @@ def input_modifier(string):
|
|||||||
This function is applied to your text inputs before
|
This function is applied to your text inputs before
|
||||||
they are fed into the model.
|
they are fed into the model.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
return string
|
return string
|
||||||
|
|
||||||
|
|
||||||
@@ -19,7 +33,6 @@ def output_modifier(string):
|
|||||||
"""
|
"""
|
||||||
This function is applied to the model outputs.
|
This function is applied to the model outputs.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
return string
|
return string
|
||||||
|
|
||||||
|
|
||||||
@@ -29,8 +42,10 @@ def bot_prefix_modifier(string):
|
|||||||
the prefix text for the Bot and can be used to bias its
|
the prefix text for the Bot and can be used to bias its
|
||||||
behavior.
|
behavior.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if params['activate']:
|
if params['activate']:
|
||||||
|
if params['use custom string']:
|
||||||
|
return f'{string} {params["custom string"].strip()} '
|
||||||
|
else:
|
||||||
return f'{string} {params["bias string"].strip()} '
|
return f'{string} {params["bias string"].strip()} '
|
||||||
else:
|
else:
|
||||||
return string
|
return string
|
||||||
@@ -39,8 +54,29 @@ def bot_prefix_modifier(string):
|
|||||||
def ui():
|
def ui():
|
||||||
# Gradio elements
|
# Gradio elements
|
||||||
activate = gr.Checkbox(value=params['activate'], label='Activate character bias')
|
activate = gr.Checkbox(value=params['activate'], label='Activate character bias')
|
||||||
string = gr.Textbox(value=params["bias string"], label='Character bias')
|
dropdown_string = gr.Dropdown(choices=bias_options, value=params["bias string"], label='Character bias', info='To edit the options in this dropdown edit the "bias_options.txt" file')
|
||||||
|
use_custom_string = gr.Checkbox(value=False, label='Use custom bias textbox instead of dropdown')
|
||||||
|
custom_string = gr.Textbox(value="", placeholder="Enter custom bias string", label="Custom Character Bias", info='To use this textbox activate the checkbox above')
|
||||||
|
|
||||||
# Event functions to update the parameters in the backend
|
# Event functions to update the parameters in the backend
|
||||||
string.change(lambda x: params.update({"bias string": x}), string, None)
|
def update_bias_string(x):
|
||||||
|
if x:
|
||||||
|
params.update({"bias string": x})
|
||||||
|
else:
|
||||||
|
params.update({"bias string": dropdown_string.get()})
|
||||||
|
return x
|
||||||
|
|
||||||
|
def update_custom_string(x):
|
||||||
|
params.update({"custom string": x})
|
||||||
|
|
||||||
|
dropdown_string.change(update_bias_string, dropdown_string, None)
|
||||||
|
custom_string.change(update_custom_string, custom_string, None)
|
||||||
activate.change(lambda x: params.update({"activate": x}), activate, None)
|
activate.change(lambda x: params.update({"activate": x}), activate, None)
|
||||||
|
use_custom_string.change(lambda x: params.update({"use custom string": x}), use_custom_string, None)
|
||||||
|
|
||||||
|
# Group elements together depending on the selected option
|
||||||
|
def bias_string_group():
|
||||||
|
if use_custom_string.value:
|
||||||
|
return gr.Group([use_custom_string, custom_string])
|
||||||
|
else:
|
||||||
|
return dropdown_string
|
||||||
|
|||||||
@@ -2,10 +2,11 @@ import re
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
import modules.shared as shared
|
|
||||||
from elevenlabslib import ElevenLabsUser
|
from elevenlabslib import ElevenLabsUser
|
||||||
from elevenlabslib.helpers import save_bytes_to_path
|
from elevenlabslib.helpers import save_bytes_to_path
|
||||||
|
|
||||||
|
import modules.shared as shared
|
||||||
|
|
||||||
params = {
|
params = {
|
||||||
'activate': True,
|
'activate': True,
|
||||||
'api_key': '12345',
|
'api_key': '12345',
|
||||||
|
|||||||
@@ -1,7 +1,8 @@
|
|||||||
import gradio as gr
|
import gradio as gr
|
||||||
import modules.shared as shared
|
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
|
||||||
|
import modules.shared as shared
|
||||||
|
|
||||||
df = pd.read_csv("https://raw.githubusercontent.com/devbrones/llama-prompts/main/prompts/prompts.csv")
|
df = pd.read_csv("https://raw.githubusercontent.com/devbrones/llama-prompts/main/prompts/prompts.csv")
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
78
extensions/sd_api_pictures/README.MD
Normal file
78
extensions/sd_api_pictures/README.MD
Normal file
@@ -0,0 +1,78 @@
|
|||||||
|
## Description:
|
||||||
|
TL;DR: Lets the bot answer you with a picture!
|
||||||
|
|
||||||
|
Stable Diffusion API pictures for TextGen, v.1.1.0
|
||||||
|
An extension to [oobabooga's textgen-webui](https://github.com/oobabooga/text-generation-webui) allowing you to receive pics generated by [Automatic1111's SD-WebUI API](https://github.com/AUTOMATIC1111/stable-diffusion-webui)
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary>Interface overview</summary>
|
||||||
|
|
||||||
|

|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
Load it in the `--chat` mode with `--extension sd_api_pictures` alongside `send_pictures` (it's not really required, but completes the picture, *pun intended*).
|
||||||
|
|
||||||
|
The image generation is triggered either:
|
||||||
|
- manually through the 'Force the picture response' button while in `Manual` or `Immersive/Interactive` modes OR
|
||||||
|
- automatically in `Immersive/Interactive` mode if the words `'send|main|message|me'` are followed by `'image|pic|picture|photo|snap|snapshot|selfie|meme'` in the user's prompt
|
||||||
|
- always on in Picturebook/Adventure mode (if not currently suppressed by 'Suppress the picture response')
|
||||||
|
|
||||||
|
## Prerequisites
|
||||||
|
|
||||||
|
One needs an available instance of Automatic1111's webui running with an `--api` flag. Ain't tested with a notebook / cloud hosted one but should be possible.
|
||||||
|
To run it locally in parallel on the same machine, specify custom `--listen-port` for either Auto1111's or ooba's webUIs.
|
||||||
|
|
||||||
|
## Features:
|
||||||
|
- API detection (press enter in the API box)
|
||||||
|
- VRAM management (model shuffling)
|
||||||
|
- Three different operation modes (manual, interactive, always-on)
|
||||||
|
- persistent settings via settings.json
|
||||||
|
|
||||||
|
The model input is modified only in the interactive mode; other two are unaffected. The output pic description is presented differently for Picture-book / Adventure mode.
|
||||||
|
|
||||||
|
Connection check (insert the Auto1111's address and press Enter):
|
||||||
|

|
||||||
|
|
||||||
|
### Persistents settings
|
||||||
|
|
||||||
|
Create or modify the `settings.json` in the `text-generation-webui` root directory to override the defaults
|
||||||
|
present in script.py, ex:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"sd_api_pictures-manage_VRAM": 1,
|
||||||
|
"sd_api_pictures-save_img": 1,
|
||||||
|
"sd_api_pictures-prompt_prefix": "(Masterpiece:1.1), detailed, intricate, colorful, (solo:1.1)",
|
||||||
|
"sd_api_pictures-sampler_name": "DPM++ 2M Karras"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
will automatically set the `Manage VRAM` & `Keep original images` checkboxes and change the texts in `Prompt Prefix` and `Sampler name` on load.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Demonstrations:
|
||||||
|
|
||||||
|
Those are examples of the version 1.0.0, but the core functionality is still the same
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary>Conversation 1</summary>
|
||||||
|
|
||||||
|

|
||||||
|

|
||||||
|

|
||||||
|

|
||||||
|
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary>Conversation 2</summary>
|
||||||
|
|
||||||
|

|
||||||
|

|
||||||
|

|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
@@ -1,34 +1,78 @@
|
|||||||
import base64
|
import base64
|
||||||
import io
|
import io
|
||||||
import re
|
import re
|
||||||
|
import time
|
||||||
|
from datetime import date
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
import modules.chat as chat
|
|
||||||
import modules.shared as shared
|
import modules.shared as shared
|
||||||
import requests
|
import requests
|
||||||
import torch
|
import torch
|
||||||
|
from modules.models import reload_model, unload_model
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
torch._C._jit_set_profiling_mode(False)
|
torch._C._jit_set_profiling_mode(False)
|
||||||
|
|
||||||
# parameters which can be customized in settings.json of webui
|
# parameters which can be customized in settings.json of webui
|
||||||
params = {
|
params = {
|
||||||
'enable_SD_api': False,
|
|
||||||
'address': 'http://127.0.0.1:7860',
|
'address': 'http://127.0.0.1:7860',
|
||||||
|
'mode': 0, # modes of operation: 0 (Manual only), 1 (Immersive/Interactive - looks for words to trigger), 2 (Picturebook Adventure - Always on)
|
||||||
|
'manage_VRAM': False,
|
||||||
'save_img': False,
|
'save_img': False,
|
||||||
'SD_model': 'NeverEndingDream', # not really used right now
|
'SD_model': 'NeverEndingDream', # not used right now
|
||||||
'prompt_prefix': '(Masterpiece:1.1), (solo:1.3), detailed, intricate, colorful',
|
'prompt_prefix': '(Masterpiece:1.1), detailed, intricate, colorful',
|
||||||
'negative_prompt': '(worst quality, low quality:1.3)',
|
'negative_prompt': '(worst quality, low quality:1.3)',
|
||||||
'side_length': 512,
|
'width': 512,
|
||||||
'restore_faces': False
|
'height': 512,
|
||||||
|
'restore_faces': False,
|
||||||
|
'seed': -1,
|
||||||
|
'sampler_name': 'DDIM',
|
||||||
|
'steps': 32,
|
||||||
|
'cfg_scale': 7
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def give_VRAM_priority(actor):
|
||||||
|
global shared, params
|
||||||
|
|
||||||
|
if actor == 'SD':
|
||||||
|
unload_model()
|
||||||
|
print("Requesting Auto1111 to re-load last checkpoint used...")
|
||||||
|
response = requests.post(url=f'{params["address"]}/sdapi/v1/reload-checkpoint', json='')
|
||||||
|
response.raise_for_status()
|
||||||
|
|
||||||
|
elif actor == 'LLM':
|
||||||
|
print("Requesting Auto1111 to vacate VRAM...")
|
||||||
|
response = requests.post(url=f'{params["address"]}/sdapi/v1/unload-checkpoint', json='')
|
||||||
|
response.raise_for_status()
|
||||||
|
reload_model()
|
||||||
|
|
||||||
|
elif actor == 'set':
|
||||||
|
print("VRAM mangement activated -- requesting Auto1111 to vacate VRAM...")
|
||||||
|
response = requests.post(url=f'{params["address"]}/sdapi/v1/unload-checkpoint', json='')
|
||||||
|
response.raise_for_status()
|
||||||
|
|
||||||
|
elif actor == 'reset':
|
||||||
|
print("VRAM mangement deactivated -- requesting Auto1111 to reload checkpoint")
|
||||||
|
response = requests.post(url=f'{params["address"]}/sdapi/v1/reload-checkpoint', json='')
|
||||||
|
response.raise_for_status()
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise RuntimeError(f'Managing VRAM: "{actor}" is not a known state!')
|
||||||
|
|
||||||
|
response.raise_for_status()
|
||||||
|
del response
|
||||||
|
|
||||||
|
|
||||||
|
if params['manage_VRAM']:
|
||||||
|
give_VRAM_priority('set')
|
||||||
|
|
||||||
|
samplers = ['DDIM', 'DPM++ 2M Karras'] # TODO: get the availible samplers with http://{address}}/sdapi/v1/samplers
|
||||||
SD_models = ['NeverEndingDream'] # TODO: get with http://{address}}/sdapi/v1/sd-models and allow user to select
|
SD_models = ['NeverEndingDream'] # TODO: get with http://{address}}/sdapi/v1/sd-models and allow user to select
|
||||||
|
|
||||||
streaming_state = shared.args.no_stream # remember if chat streaming was enabled
|
streaming_state = shared.args.no_stream # remember if chat streaming was enabled
|
||||||
picture_response = False # specifies if the next model response should appear as a picture
|
picture_response = False # specifies if the next model response should appear as a picture
|
||||||
pic_id = 0
|
|
||||||
|
|
||||||
|
|
||||||
def remove_surrounded_chars(string):
|
def remove_surrounded_chars(string):
|
||||||
@@ -36,7 +80,13 @@ def remove_surrounded_chars(string):
|
|||||||
# 'as few symbols as possible (0 upwards) between an asterisk and the end of the string'
|
# 'as few symbols as possible (0 upwards) between an asterisk and the end of the string'
|
||||||
return re.sub('\*[^\*]*?(\*|$)', '', string)
|
return re.sub('\*[^\*]*?(\*|$)', '', string)
|
||||||
|
|
||||||
# I don't even need input_hijack for this as visible text will be commited to history as the unmodified string
|
|
||||||
|
def triggers_are_in(string):
|
||||||
|
string = remove_surrounded_chars(string)
|
||||||
|
# regex searches for send|main|message|me (at the end of the word) followed by
|
||||||
|
# a whole word of image|pic|picture|photo|snap|snapshot|selfie|meme(s),
|
||||||
|
# (?aims) are regex parser flags
|
||||||
|
return bool(re.search('(?aims)(send|mail|message|me)\\b.+?\\b(image|pic(ture)?|photo|snap(shot)?|selfie|meme)s?\\b', string))
|
||||||
|
|
||||||
|
|
||||||
def input_modifier(string):
|
def input_modifier(string):
|
||||||
@@ -44,55 +94,58 @@ def input_modifier(string):
|
|||||||
This function is applied to your text inputs before
|
This function is applied to your text inputs before
|
||||||
they are fed into the model.
|
they are fed into the model.
|
||||||
"""
|
"""
|
||||||
global params, picture_response
|
|
||||||
if not params['enable_SD_api']:
|
global params
|
||||||
|
|
||||||
|
if not params['mode'] == 1: # if not in immersive/interactive mode, do nothing
|
||||||
return string
|
return string
|
||||||
|
|
||||||
commands = ['send', 'mail', 'me']
|
if triggers_are_in(string): # if we're in it, check for trigger words
|
||||||
mediums = ['image', 'pic', 'picture', 'photo']
|
toggle_generation(True)
|
||||||
subjects = ['yourself', 'own']
|
string = string.lower()
|
||||||
lowstr = string.lower()
|
if "of" in string:
|
||||||
|
subject = string.split('of', 1)[1] # subdivide the string once by the first 'of' instance and get what's coming after it
|
||||||
# TODO: refactor out to separate handler and also replace detection with a regexp
|
string = "Please provide a detailed and vivid description of " + subject
|
||||||
if any(command in lowstr for command in commands) and any(case in lowstr for case in mediums): # trigger the generation if a command signature and a medium signature is found
|
else:
|
||||||
picture_response = True
|
string = "Please provide a detailed description of your appearance, your surroundings and what you are doing right now"
|
||||||
shared.args.no_stream = True # Disable streaming cause otherwise the SD-generated picture would return as a dud
|
|
||||||
shared.processing_message = "*Is sending a picture...*"
|
|
||||||
string = "Please provide a detailed description of your surroundings, how you look and the situation you're in and what you are doing right now"
|
|
||||||
if any(target in lowstr for target in subjects): # the focus of the image should be on the sending character
|
|
||||||
string = "Please provide a detailed and vivid description of how you look and what you are wearing"
|
|
||||||
|
|
||||||
return string
|
return string
|
||||||
|
|
||||||
# Get and save the Stable Diffusion-generated picture
|
# Get and save the Stable Diffusion-generated picture
|
||||||
|
|
||||||
|
|
||||||
def get_SD_pictures(description):
|
def get_SD_pictures(description):
|
||||||
|
|
||||||
global params, pic_id
|
global params
|
||||||
|
|
||||||
|
if params['manage_VRAM']:
|
||||||
|
give_VRAM_priority('SD')
|
||||||
|
|
||||||
payload = {
|
payload = {
|
||||||
"prompt": params['prompt_prefix'] + description,
|
"prompt": params['prompt_prefix'] + description,
|
||||||
"seed": -1,
|
"seed": params['seed'],
|
||||||
"sampler_name": "DPM++ 2M Karras",
|
"sampler_name": params['sampler_name'],
|
||||||
"steps": 32,
|
"steps": params['steps'],
|
||||||
"cfg_scale": 7,
|
"cfg_scale": params['cfg_scale'],
|
||||||
"width": params['side_length'],
|
"width": params['width'],
|
||||||
"height": params['side_length'],
|
"height": params['height'],
|
||||||
"restore_faces": params['restore_faces'],
|
"restore_faces": params['restore_faces'],
|
||||||
"negative_prompt": params['negative_prompt']
|
"negative_prompt": params['negative_prompt']
|
||||||
}
|
}
|
||||||
|
|
||||||
|
print(f'Prompting the image generator via the API on {params["address"]}...')
|
||||||
response = requests.post(url=f'{params["address"]}/sdapi/v1/txt2img', json=payload)
|
response = requests.post(url=f'{params["address"]}/sdapi/v1/txt2img', json=payload)
|
||||||
|
response.raise_for_status()
|
||||||
r = response.json()
|
r = response.json()
|
||||||
|
|
||||||
visible_result = ""
|
visible_result = ""
|
||||||
for img_str in r['images']:
|
for img_str in r['images']:
|
||||||
image = Image.open(io.BytesIO(base64.b64decode(img_str.split(",", 1)[0])))
|
image = Image.open(io.BytesIO(base64.b64decode(img_str.split(",", 1)[0])))
|
||||||
if params['save_img']:
|
if params['save_img']:
|
||||||
output_file = Path(f'extensions/sd_api_pictures/outputs/{pic_id:06d}.png')
|
variadic = f'{date.today().strftime("%Y_%m_%d")}/{shared.character}_{int(time.time())}'
|
||||||
|
output_file = Path(f'extensions/sd_api_pictures/outputs/{variadic}.png')
|
||||||
|
output_file.parent.mkdir(parents=True, exist_ok=True)
|
||||||
image.save(output_file.as_posix())
|
image.save(output_file.as_posix())
|
||||||
pic_id += 1
|
visible_result = visible_result + f'<img src="/file/extensions/sd_api_pictures/outputs/{variadic}.png" alt="{description}" style="max-width: unset; max-height: unset;">\n'
|
||||||
|
else:
|
||||||
# lower the resolution of received images for the chat, otherwise the log size gets out of control quickly with all the base64 values in visible history
|
# lower the resolution of received images for the chat, otherwise the log size gets out of control quickly with all the base64 values in visible history
|
||||||
image.thumbnail((300, 300))
|
image.thumbnail((300, 300))
|
||||||
buffered = io.BytesIO()
|
buffered = io.BytesIO()
|
||||||
@@ -102,17 +155,19 @@ def get_SD_pictures(description):
|
|||||||
img_str = "data:image/jpeg;base64," + base64.b64encode(image_bytes).decode()
|
img_str = "data:image/jpeg;base64," + base64.b64encode(image_bytes).decode()
|
||||||
visible_result = visible_result + f'<img src="{img_str}" alt="{description}">\n'
|
visible_result = visible_result + f'<img src="{img_str}" alt="{description}">\n'
|
||||||
|
|
||||||
|
if params['manage_VRAM']:
|
||||||
|
give_VRAM_priority('LLM')
|
||||||
|
|
||||||
return visible_result
|
return visible_result
|
||||||
|
|
||||||
# TODO: how do I make the UI history ignore the resulting pictures (I don't want HTML to appear in history)
|
# TODO: how do I make the UI history ignore the resulting pictures (I don't want HTML to appear in history)
|
||||||
# and replace it with 'text' for the purposes of logging?
|
# and replace it with 'text' for the purposes of logging?
|
||||||
|
|
||||||
|
|
||||||
def output_modifier(string):
|
def output_modifier(string):
|
||||||
"""
|
"""
|
||||||
This function is applied to the model outputs.
|
This function is applied to the model outputs.
|
||||||
"""
|
"""
|
||||||
global pic_id, picture_response, streaming_state
|
|
||||||
|
global picture_response, params
|
||||||
|
|
||||||
if not picture_response:
|
if not picture_response:
|
||||||
return string
|
return string
|
||||||
@@ -125,17 +180,18 @@ def output_modifier(string):
|
|||||||
|
|
||||||
if string == '':
|
if string == '':
|
||||||
string = 'no viable description in reply, try regenerating'
|
string = 'no viable description in reply, try regenerating'
|
||||||
|
return string
|
||||||
|
|
||||||
# I can't for the love of all that's holy get the name from shared.gradio['name1'], so for now it will be like this
|
text = ""
|
||||||
text = f'*Description: "{string}"*'
|
if (params['mode'] < 2):
|
||||||
|
toggle_generation(False)
|
||||||
|
text = f'*Sends a picture which portrays: “{string}”*'
|
||||||
|
else:
|
||||||
|
text = string
|
||||||
|
|
||||||
image = get_SD_pictures(string)
|
string = get_SD_pictures(string) + "\n" + text
|
||||||
|
|
||||||
picture_response = False
|
return string
|
||||||
|
|
||||||
shared.processing_message = "*Is typing...*"
|
|
||||||
shared.args.no_stream = streaming_state
|
|
||||||
return image + "\n" + text
|
|
||||||
|
|
||||||
|
|
||||||
def bot_prefix_modifier(string):
|
def bot_prefix_modifier(string):
|
||||||
@@ -148,42 +204,91 @@ def bot_prefix_modifier(string):
|
|||||||
return string
|
return string
|
||||||
|
|
||||||
|
|
||||||
def force_pic():
|
def toggle_generation(*args):
|
||||||
global picture_response
|
global picture_response, shared, streaming_state
|
||||||
picture_response = True
|
|
||||||
|
if not args:
|
||||||
|
picture_response = not picture_response
|
||||||
|
else:
|
||||||
|
picture_response = args[0]
|
||||||
|
|
||||||
|
shared.args.no_stream = True if picture_response else streaming_state # Disable streaming cause otherwise the SD-generated picture would return as a dud
|
||||||
|
shared.processing_message = "*Is sending a picture...*" if picture_response else "*Is typing...*"
|
||||||
|
|
||||||
|
|
||||||
|
def filter_address(address):
|
||||||
|
address = address.strip()
|
||||||
|
# address = re.sub('http(s)?:\/\/|\/$','',address) # remove starting http:// OR https:// OR trailing slash
|
||||||
|
address = re.sub('\/$', '', address) # remove trailing /s
|
||||||
|
if not address.startswith('http'):
|
||||||
|
address = 'http://' + address
|
||||||
|
return address
|
||||||
|
|
||||||
|
|
||||||
|
def SD_api_address_update(address):
|
||||||
|
|
||||||
|
global params
|
||||||
|
|
||||||
|
msg = "✔️ SD API is found on:"
|
||||||
|
address = filter_address(address)
|
||||||
|
params.update({"address": address})
|
||||||
|
try:
|
||||||
|
response = requests.get(url=f'{params["address"]}/sdapi/v1/sd-models')
|
||||||
|
response.raise_for_status()
|
||||||
|
# r = response.json()
|
||||||
|
except:
|
||||||
|
msg = "❌ No SD API endpoint on:"
|
||||||
|
|
||||||
|
return gr.Textbox.update(label=msg)
|
||||||
|
|
||||||
|
|
||||||
def ui():
|
def ui():
|
||||||
|
|
||||||
# Gradio elements
|
# Gradio elements
|
||||||
with gr.Accordion("Stable Diffusion api integration", open=True):
|
# gr.Markdown('### Stable Diffusion API Pictures') # Currently the name of extension is shown as the title
|
||||||
|
with gr.Accordion("Parameters", open=True):
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
with gr.Column():
|
address = gr.Textbox(placeholder=params['address'], value=params['address'], label='Auto1111\'s WebUI address')
|
||||||
enable = gr.Checkbox(value=params['enable_SD_api'], label='Activate SD Api integration')
|
mode = gr.Dropdown(["Manual", "Immersive/Interactive", "Picturebook/Adventure"], value="Manual", label="Mode of operation", type="index")
|
||||||
save_img = gr.Checkbox(value=params['save_img'], label='Keep original received images in the outputs subdir')
|
with gr.Column(scale=1, min_width=300):
|
||||||
with gr.Column():
|
manage_VRAM = gr.Checkbox(value=params['manage_VRAM'], label='Manage VRAM')
|
||||||
address = gr.Textbox(placeholder=params['address'], value=params['address'], label='Stable Diffusion host address')
|
save_img = gr.Checkbox(value=params['save_img'], label='Keep original images and use them in chat')
|
||||||
|
|
||||||
with gr.Row():
|
force_pic = gr.Button("Force the picture response")
|
||||||
force_btn = gr.Button("Force the next response to be a picture")
|
suppr_pic = gr.Button("Suppress the picture response")
|
||||||
generate_now_btn = gr.Button("Generate an image response to the input")
|
|
||||||
|
|
||||||
with gr.Accordion("Generation parameters", open=False):
|
with gr.Accordion("Generation parameters", open=False):
|
||||||
prompt_prefix = gr.Textbox(placeholder=params['prompt_prefix'], value=params['prompt_prefix'], label='Prompt Prefix (best used to describe the look of the character)')
|
prompt_prefix = gr.Textbox(placeholder=params['prompt_prefix'], value=params['prompt_prefix'], label='Prompt Prefix (best used to describe the look of the character)')
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
|
with gr.Column():
|
||||||
negative_prompt = gr.Textbox(placeholder=params['negative_prompt'], value=params['negative_prompt'], label='Negative Prompt')
|
negative_prompt = gr.Textbox(placeholder=params['negative_prompt'], value=params['negative_prompt'], label='Negative Prompt')
|
||||||
dimensions = gr.Slider(256, 702, value=params['side_length'], step=64, label='Image dimensions')
|
sampler_name = gr.Textbox(placeholder=params['sampler_name'], value=params['sampler_name'], label='Sampler')
|
||||||
# model = gr.Dropdown(value=SD_models[0], choices=SD_models, label='Model')
|
with gr.Column():
|
||||||
|
width = gr.Slider(256, 768, value=params['width'], step=64, label='Width')
|
||||||
|
height = gr.Slider(256, 768, value=params['height'], step=64, label='Height')
|
||||||
|
with gr.Row():
|
||||||
|
steps = gr.Number(label="Steps:", value=params['steps'])
|
||||||
|
seed = gr.Number(label="Seed:", value=params['seed'])
|
||||||
|
cfg_scale = gr.Number(label="CFG Scale:", value=params['cfg_scale'])
|
||||||
|
|
||||||
# Event functions to update the parameters in the backend
|
# Event functions to update the parameters in the backend
|
||||||
enable.change(lambda x: params.update({"enable_SD_api": x}), enable, None)
|
address.change(lambda x: params.update({"address": filter_address(x)}), address, None)
|
||||||
|
mode.select(lambda x: params.update({"mode": x}), mode, None)
|
||||||
|
mode.select(lambda x: toggle_generation(x > 1), inputs=mode, outputs=None)
|
||||||
|
manage_VRAM.change(lambda x: params.update({"manage_VRAM": x}), manage_VRAM, None)
|
||||||
|
manage_VRAM.change(lambda x: give_VRAM_priority('set' if x else 'reset'), inputs=manage_VRAM, outputs=None)
|
||||||
save_img.change(lambda x: params.update({"save_img": x}), save_img, None)
|
save_img.change(lambda x: params.update({"save_img": x}), save_img, None)
|
||||||
address.change(lambda x: params.update({"address": x}), address, None)
|
|
||||||
|
address.submit(fn=SD_api_address_update, inputs=address, outputs=address)
|
||||||
prompt_prefix.change(lambda x: params.update({"prompt_prefix": x}), prompt_prefix, None)
|
prompt_prefix.change(lambda x: params.update({"prompt_prefix": x}), prompt_prefix, None)
|
||||||
negative_prompt.change(lambda x: params.update({"negative_prompt": x}), negative_prompt, None)
|
negative_prompt.change(lambda x: params.update({"negative_prompt": x}), negative_prompt, None)
|
||||||
dimensions.change(lambda x: params.update({"side_length": x}), dimensions, None)
|
width.change(lambda x: params.update({"width": x}), width, None)
|
||||||
# model.change(lambda x: params.update({"SD_model": x}), model, None)
|
height.change(lambda x: params.update({"height": x}), height, None)
|
||||||
|
|
||||||
force_btn.click(force_pic)
|
sampler_name.change(lambda x: params.update({"sampler_name": x}), sampler_name, None)
|
||||||
generate_now_btn.click(force_pic)
|
steps.change(lambda x: params.update({"steps": x}), steps, None)
|
||||||
generate_now_btn.click(chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream)
|
seed.change(lambda x: params.update({"seed": x}), seed, None)
|
||||||
|
cfg_scale.change(lambda x: params.update({"cfg_scale": x}), cfg_scale, None)
|
||||||
|
|
||||||
|
force_pic.click(lambda x: toggle_generation(True), inputs=force_pic, outputs=None)
|
||||||
|
suppr_pic.click(lambda x: toggle_generation(False), inputs=suppr_pic, outputs=None)
|
||||||
|
|||||||
@@ -25,7 +25,7 @@ def caption_image(raw_image):
|
|||||||
|
|
||||||
|
|
||||||
def generate_chat_picture(picture, name1, name2):
|
def generate_chat_picture(picture, name1, name2):
|
||||||
text = f'*{name1} sends {name2} a picture that contains the following: "{caption_image(picture)}"*'
|
text = f'*{name1} sends {name2} a picture that contains the following: “{caption_image(picture)}”*'
|
||||||
# lower the resolution of sent images for the chat, otherwise the log size gets out of control quickly with all the base64 values in visible history
|
# lower the resolution of sent images for the chat, otherwise the log size gets out of control quickly with all the base64 values in visible history
|
||||||
picture.thumbnail((300, 300))
|
picture.thumbnail((300, 300))
|
||||||
buffer = BytesIO()
|
buffer = BytesIO()
|
||||||
|
|||||||
@@ -1,6 +1,5 @@
|
|||||||
ipython
|
ipython
|
||||||
|
num2words
|
||||||
omegaconf
|
omegaconf
|
||||||
pydub
|
pydub
|
||||||
PyYAML
|
PyYAML
|
||||||
torch
|
|
||||||
torchaudio
|
|
||||||
|
|||||||
@@ -1,14 +1,16 @@
|
|||||||
import re
|
|
||||||
import time
|
import time
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
import modules.chat as chat
|
|
||||||
import modules.shared as shared
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from extensions.silero_tts import tts_preprocessor
|
||||||
|
from modules import chat, shared
|
||||||
|
from modules.html_generator import chat_html_wrapper
|
||||||
|
|
||||||
torch._C._jit_set_profiling_mode(False)
|
torch._C._jit_set_profiling_mode(False)
|
||||||
|
|
||||||
|
|
||||||
params = {
|
params = {
|
||||||
'activate': True,
|
'activate': True,
|
||||||
'speaker': 'en_56',
|
'speaker': 'en_56',
|
||||||
@@ -20,6 +22,7 @@ params = {
|
|||||||
'autoplay': True,
|
'autoplay': True,
|
||||||
'voice_pitch': 'medium',
|
'voice_pitch': 'medium',
|
||||||
'voice_speed': 'medium',
|
'voice_speed': 'medium',
|
||||||
|
'local_cache_path': '' # User can override the default cache path to something other via settings.json
|
||||||
}
|
}
|
||||||
|
|
||||||
current_params = params.copy()
|
current_params = params.copy()
|
||||||
@@ -37,26 +40,31 @@ table = str.maketrans({
|
|||||||
'"': """,
|
'"': """,
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
||||||
def xmlesc(txt):
|
def xmlesc(txt):
|
||||||
return txt.translate(table)
|
return txt.translate(table)
|
||||||
|
|
||||||
|
|
||||||
def load_model():
|
def load_model():
|
||||||
|
torch_cache_path = torch.hub.get_dir() if params['local_cache_path'] == '' else params['local_cache_path']
|
||||||
|
model_path = torch_cache_path + "/snakers4_silero-models_master/src/silero/model/" + params['model_id'] + ".pt"
|
||||||
|
if Path(model_path).is_file():
|
||||||
|
print(f'\nUsing Silero TTS cached checkpoint found at {torch_cache_path}')
|
||||||
|
model, example_text = torch.hub.load(repo_or_dir=torch_cache_path + '/snakers4_silero-models_master/', model='silero_tts', language=params['language'], speaker=params['model_id'], source='local', path=model_path, force_reload=True)
|
||||||
|
else:
|
||||||
|
print(f'\nSilero TTS cache not found at {torch_cache_path}. Attempting to download...')
|
||||||
model, example_text = torch.hub.load(repo_or_dir='snakers4/silero-models', model='silero_tts', language=params['language'], speaker=params['model_id'])
|
model, example_text = torch.hub.load(repo_or_dir='snakers4/silero-models', model='silero_tts', language=params['language'], speaker=params['model_id'])
|
||||||
model.to(params['device'])
|
model.to(params['device'])
|
||||||
return model
|
return model
|
||||||
model = load_model()
|
|
||||||
|
|
||||||
def remove_surrounded_chars(string):
|
|
||||||
# this expression matches to 'as few symbols as possible (0 upwards) between any asterisks' OR
|
|
||||||
# 'as few symbols as possible (0 upwards) between an asterisk and the end of the string'
|
|
||||||
return re.sub('\*[^\*]*?(\*|$)','',string)
|
|
||||||
|
|
||||||
def remove_tts_from_history(name1, name2):
|
def remove_tts_from_history(name1, name2, mode):
|
||||||
for i, entry in enumerate(shared.history['internal']):
|
for i, entry in enumerate(shared.history['internal']):
|
||||||
shared.history['visible'][i] = [shared.history['visible'][i][0], entry[1]]
|
shared.history['visible'][i] = [shared.history['visible'][i][0], entry[1]]
|
||||||
return chat.generate_chat_output(shared.history['visible'], name1, name2, shared.character)
|
return chat_html_wrapper(shared.history['visible'], name1, name2, mode)
|
||||||
|
|
||||||
def toggle_text_in_history(name1, name2):
|
|
||||||
|
def toggle_text_in_history(name1, name2, mode):
|
||||||
for i, entry in enumerate(shared.history['visible']):
|
for i, entry in enumerate(shared.history['visible']):
|
||||||
visible_reply = entry[1]
|
visible_reply = entry[1]
|
||||||
if visible_reply.startswith('<audio'):
|
if visible_reply.startswith('<audio'):
|
||||||
@@ -65,7 +73,8 @@ def toggle_text_in_history(name1, name2):
|
|||||||
shared.history['visible'][i] = [shared.history['visible'][i][0], f"{visible_reply.split('</audio>')[0]}</audio>\n\n{reply}"]
|
shared.history['visible'][i] = [shared.history['visible'][i][0], f"{visible_reply.split('</audio>')[0]}</audio>\n\n{reply}"]
|
||||||
else:
|
else:
|
||||||
shared.history['visible'][i] = [shared.history['visible'][i][0], f"{visible_reply.split('</audio>')[0]}</audio>"]
|
shared.history['visible'][i] = [shared.history['visible'][i][0], f"{visible_reply.split('</audio>')[0]}</audio>"]
|
||||||
return chat.generate_chat_output(shared.history['visible'], name1, name2, shared.character)
|
return chat_html_wrapper(shared.history['visible'], name1, name2, mode)
|
||||||
|
|
||||||
|
|
||||||
def input_modifier(string):
|
def input_modifier(string):
|
||||||
"""
|
"""
|
||||||
@@ -75,12 +84,13 @@ def input_modifier(string):
|
|||||||
|
|
||||||
# Remove autoplay from the last reply
|
# Remove autoplay from the last reply
|
||||||
if shared.is_chat() and len(shared.history['internal']) > 0:
|
if shared.is_chat() and len(shared.history['internal']) > 0:
|
||||||
shared.history['visible'][-1] = [shared.history['visible'][-1][0], shared.history['visible'][-1][1].replace('controls autoplay>','controls>')]
|
shared.history['visible'][-1] = [shared.history['visible'][-1][0], shared.history['visible'][-1][1].replace('controls autoplay>', 'controls>')]
|
||||||
|
|
||||||
shared.processing_message = "*Is recording a voice message...*"
|
shared.processing_message = "*Is recording a voice message...*"
|
||||||
shared.args.no_stream = True # Disable streaming cause otherwise the audio output will stutter and begin anew every time the message is being updated
|
shared.args.no_stream = True # Disable streaming cause otherwise the audio output will stutter and begin anew every time the message is being updated
|
||||||
return string
|
return string
|
||||||
|
|
||||||
|
|
||||||
def output_modifier(string):
|
def output_modifier(string):
|
||||||
"""
|
"""
|
||||||
This function is applied to the model outputs.
|
This function is applied to the model outputs.
|
||||||
@@ -94,15 +104,11 @@ def output_modifier(string):
|
|||||||
current_params = params.copy()
|
current_params = params.copy()
|
||||||
break
|
break
|
||||||
|
|
||||||
if params['activate'] == False:
|
if not params['activate']:
|
||||||
return string
|
return string
|
||||||
|
|
||||||
original_string = string
|
original_string = string
|
||||||
string = remove_surrounded_chars(string)
|
string = tts_preprocessor.preprocess(string)
|
||||||
string = string.replace('"', '')
|
|
||||||
string = string.replace('“', '')
|
|
||||||
string = string.replace('\n', ' ')
|
|
||||||
string = string.strip()
|
|
||||||
|
|
||||||
if string == '':
|
if string == '':
|
||||||
string = '*Empty reply, try regenerating*'
|
string = '*Empty reply, try regenerating*'
|
||||||
@@ -121,6 +127,7 @@ def output_modifier(string):
|
|||||||
shared.args.no_stream = streaming_state # restore the streaming option to the previous value
|
shared.args.no_stream = streaming_state # restore the streaming option to the previous value
|
||||||
return string
|
return string
|
||||||
|
|
||||||
|
|
||||||
def bot_prefix_modifier(string):
|
def bot_prefix_modifier(string):
|
||||||
"""
|
"""
|
||||||
This function is only applied in chat mode. It modifies
|
This function is only applied in chat mode. It modifies
|
||||||
@@ -130,17 +137,25 @@ def bot_prefix_modifier(string):
|
|||||||
|
|
||||||
return string
|
return string
|
||||||
|
|
||||||
|
|
||||||
|
def setup():
|
||||||
|
global model
|
||||||
|
model = load_model()
|
||||||
|
|
||||||
|
|
||||||
def ui():
|
def ui():
|
||||||
# Gradio elements
|
# Gradio elements
|
||||||
with gr.Accordion("Silero TTS"):
|
with gr.Accordion("Silero TTS"):
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
activate = gr.Checkbox(value=params['activate'], label='Activate TTS')
|
activate = gr.Checkbox(value=params['activate'], label='Activate TTS')
|
||||||
autoplay = gr.Checkbox(value=params['autoplay'], label='Play TTS automatically')
|
autoplay = gr.Checkbox(value=params['autoplay'], label='Play TTS automatically')
|
||||||
|
|
||||||
show_text = gr.Checkbox(value=params['show_text'], label='Show message text under audio player')
|
show_text = gr.Checkbox(value=params['show_text'], label='Show message text under audio player')
|
||||||
voice = gr.Dropdown(value=params['speaker'], choices=voices_by_gender, label='TTS voice')
|
voice = gr.Dropdown(value=params['speaker'], choices=voices_by_gender, label='TTS voice')
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
v_pitch = gr.Dropdown(value=params['voice_pitch'], choices=voice_pitches, label='Voice pitch')
|
v_pitch = gr.Dropdown(value=params['voice_pitch'], choices=voice_pitches, label='Voice pitch')
|
||||||
v_speed = gr.Dropdown(value=params['voice_speed'], choices=voice_speeds, label='Voice speed')
|
v_speed = gr.Dropdown(value=params['voice_speed'], choices=voice_speeds, label='Voice speed')
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
convert = gr.Button('Permanently replace audios with the message texts')
|
convert = gr.Button('Permanently replace audios with the message texts')
|
||||||
convert_cancel = gr.Button('Cancel', visible=False)
|
convert_cancel = gr.Button('Cancel', visible=False)
|
||||||
@@ -148,16 +163,16 @@ def ui():
|
|||||||
|
|
||||||
# Convert history with confirmation
|
# Convert history with confirmation
|
||||||
convert_arr = [convert_confirm, convert, convert_cancel]
|
convert_arr = [convert_confirm, convert, convert_cancel]
|
||||||
convert.click(lambda :[gr.update(visible=True), gr.update(visible=False), gr.update(visible=True)], None, convert_arr)
|
convert.click(lambda: [gr.update(visible=True), gr.update(visible=False), gr.update(visible=True)], None, convert_arr)
|
||||||
convert_confirm.click(lambda :[gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, convert_arr)
|
convert_confirm.click(lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, convert_arr)
|
||||||
convert_confirm.click(remove_tts_from_history, [shared.gradio['name1'], shared.gradio['name2']], shared.gradio['display'])
|
convert_confirm.click(remove_tts_from_history, [shared.gradio[k] for k in ['name1', 'name2', 'mode']], shared.gradio['display'])
|
||||||
convert_confirm.click(lambda : chat.save_history(timestamp=False), [], [], show_progress=False)
|
convert_confirm.click(lambda: chat.save_history(timestamp=False), [], [], show_progress=False)
|
||||||
convert_cancel.click(lambda :[gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, convert_arr)
|
convert_cancel.click(lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, convert_arr)
|
||||||
|
|
||||||
# Toggle message text in history
|
# Toggle message text in history
|
||||||
show_text.change(lambda x: params.update({"show_text": x}), show_text, None)
|
show_text.change(lambda x: params.update({"show_text": x}), show_text, None)
|
||||||
show_text.change(toggle_text_in_history, [shared.gradio['name1'], shared.gradio['name2']], shared.gradio['display'])
|
show_text.change(toggle_text_in_history, [shared.gradio[k] for k in ['name1', 'name2', 'mode']], shared.gradio['display'])
|
||||||
show_text.change(lambda : chat.save_history(timestamp=False), [], [], show_progress=False)
|
show_text.change(lambda: chat.save_history(timestamp=False), [], [], show_progress=False)
|
||||||
|
|
||||||
# Event functions to update the parameters in the backend
|
# Event functions to update the parameters in the backend
|
||||||
activate.change(lambda x: params.update({"activate": x}), activate, None)
|
activate.change(lambda x: params.update({"activate": x}), activate, None)
|
||||||
|
|||||||
81
extensions/silero_tts/test_tts.py
Normal file
81
extensions/silero_tts/test_tts.py
Normal file
@@ -0,0 +1,81 @@
|
|||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import tts_preprocessor
|
||||||
|
|
||||||
|
torch._C._jit_set_profiling_mode(False)
|
||||||
|
|
||||||
|
|
||||||
|
params = {
|
||||||
|
'activate': True,
|
||||||
|
'speaker': 'en_49',
|
||||||
|
'language': 'en',
|
||||||
|
'model_id': 'v3_en',
|
||||||
|
'sample_rate': 48000,
|
||||||
|
'device': 'cpu',
|
||||||
|
'show_text': True,
|
||||||
|
'autoplay': True,
|
||||||
|
'voice_pitch': 'medium',
|
||||||
|
'voice_speed': 'medium',
|
||||||
|
}
|
||||||
|
|
||||||
|
current_params = params.copy()
|
||||||
|
voices_by_gender = ['en_99', 'en_45', 'en_18', 'en_117', 'en_49', 'en_51', 'en_68', 'en_0', 'en_26', 'en_56', 'en_74', 'en_5', 'en_38', 'en_53', 'en_21', 'en_37', 'en_107', 'en_10', 'en_82', 'en_16', 'en_41', 'en_12', 'en_67', 'en_61', 'en_14', 'en_11', 'en_39', 'en_52', 'en_24', 'en_97', 'en_28', 'en_72', 'en_94', 'en_36', 'en_4', 'en_43', 'en_88', 'en_25', 'en_65', 'en_6', 'en_44', 'en_75', 'en_91', 'en_60', 'en_109', 'en_85', 'en_101', 'en_108', 'en_50', 'en_96', 'en_64', 'en_92', 'en_76', 'en_33', 'en_116', 'en_48', 'en_98', 'en_86', 'en_62', 'en_54', 'en_95', 'en_55', 'en_111', 'en_3', 'en_83', 'en_8', 'en_47', 'en_59', 'en_1', 'en_2', 'en_7', 'en_9', 'en_13', 'en_15', 'en_17', 'en_19', 'en_20', 'en_22', 'en_23', 'en_27', 'en_29', 'en_30', 'en_31', 'en_32', 'en_34', 'en_35', 'en_40', 'en_42', 'en_46', 'en_57', 'en_58', 'en_63', 'en_66', 'en_69', 'en_70', 'en_71', 'en_73', 'en_77', 'en_78', 'en_79', 'en_80', 'en_81', 'en_84', 'en_87', 'en_89', 'en_90', 'en_93', 'en_100', 'en_102', 'en_103', 'en_104', 'en_105', 'en_106', 'en_110', 'en_112', 'en_113', 'en_114', 'en_115']
|
||||||
|
voice_pitches = ['x-low', 'low', 'medium', 'high', 'x-high']
|
||||||
|
voice_speeds = ['x-slow', 'slow', 'medium', 'fast', 'x-fast']
|
||||||
|
|
||||||
|
# Used for making text xml compatible, needed for voice pitch and speed control
|
||||||
|
table = str.maketrans({
|
||||||
|
"<": "<",
|
||||||
|
">": ">",
|
||||||
|
"&": "&",
|
||||||
|
"'": "'",
|
||||||
|
'"': """,
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
|
def xmlesc(txt):
|
||||||
|
return txt.translate(table)
|
||||||
|
|
||||||
|
|
||||||
|
def load_model():
|
||||||
|
model, example_text = torch.hub.load(repo_or_dir='snakers4/silero-models', model='silero_tts', language=params['language'], speaker=params['model_id'])
|
||||||
|
model.to(params['device'])
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
model = load_model()
|
||||||
|
|
||||||
|
|
||||||
|
def output_modifier(string):
|
||||||
|
"""
|
||||||
|
This function is applied to the model outputs.
|
||||||
|
"""
|
||||||
|
|
||||||
|
global model, current_params
|
||||||
|
|
||||||
|
original_string = string
|
||||||
|
string = tts_preprocessor.preprocess(string)
|
||||||
|
processed_string = string
|
||||||
|
|
||||||
|
if string == '':
|
||||||
|
string = '*Empty reply, try regenerating*'
|
||||||
|
else:
|
||||||
|
output_file = Path(f'extensions/silero_tts/outputs/test_{int(time.time())}.wav')
|
||||||
|
prosody = '<prosody rate="{}" pitch="{}">'.format(params['voice_speed'], params['voice_pitch'])
|
||||||
|
silero_input = f'<speak>{prosody}{xmlesc(string)}</prosody></speak>'
|
||||||
|
model.save_wav(ssml_text=silero_input, speaker=params['speaker'], sample_rate=int(params['sample_rate']), audio_path=str(output_file))
|
||||||
|
|
||||||
|
autoplay = 'autoplay' if params['autoplay'] else ''
|
||||||
|
string = f'<audio src="file/{output_file.as_posix()}" controls {autoplay}></audio>'
|
||||||
|
|
||||||
|
if params['show_text']:
|
||||||
|
string += f'\n\n{original_string}\n\nProcessed:\n{processed_string}'
|
||||||
|
|
||||||
|
print(string)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
import sys
|
||||||
|
output_modifier(sys.argv[1])
|
||||||
194
extensions/silero_tts/tts_preprocessor.py
Normal file
194
extensions/silero_tts/tts_preprocessor.py
Normal file
@@ -0,0 +1,194 @@
|
|||||||
|
import re
|
||||||
|
|
||||||
|
from num2words import num2words
|
||||||
|
|
||||||
|
punctuation = r'[\s,.?!/)\'\]>]'
|
||||||
|
alphabet_map = {
|
||||||
|
"A": " Ei ",
|
||||||
|
"B": " Bee ",
|
||||||
|
"C": " See ",
|
||||||
|
"D": " Dee ",
|
||||||
|
"E": " Eee ",
|
||||||
|
"F": " Eff ",
|
||||||
|
"G": " Jee ",
|
||||||
|
"H": " Eich ",
|
||||||
|
"I": " Eye ",
|
||||||
|
"J": " Jay ",
|
||||||
|
"K": " Kay ",
|
||||||
|
"L": " El ",
|
||||||
|
"M": " Emm ",
|
||||||
|
"N": " Enn ",
|
||||||
|
"O": " Ohh ",
|
||||||
|
"P": " Pee ",
|
||||||
|
"Q": " Queue ",
|
||||||
|
"R": " Are ",
|
||||||
|
"S": " Ess ",
|
||||||
|
"T": " Tee ",
|
||||||
|
"U": " You ",
|
||||||
|
"V": " Vee ",
|
||||||
|
"W": " Double You ",
|
||||||
|
"X": " Ex ",
|
||||||
|
"Y": " Why ",
|
||||||
|
"Z": " Zed " # Zed is weird, as I (da3dsoul) am American, but most of the voice models sound British, so it matches
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def preprocess(string):
|
||||||
|
# the order for some of these matter
|
||||||
|
# For example, you need to remove the commas in numbers before expanding them
|
||||||
|
string = remove_surrounded_chars(string)
|
||||||
|
string = string.replace('"', '')
|
||||||
|
string = string.replace('\u201D', '').replace('\u201C', '') # right and left quote
|
||||||
|
string = string.replace('\u201F', '') # italic looking quote
|
||||||
|
string = string.replace('\n', ' ')
|
||||||
|
string = convert_num_locale(string)
|
||||||
|
string = replace_negative(string)
|
||||||
|
string = replace_roman(string)
|
||||||
|
string = hyphen_range_to(string)
|
||||||
|
string = num_to_words(string)
|
||||||
|
|
||||||
|
# TODO Try to use a ML predictor to expand abbreviations. It's hard, dependent on context, and whether to actually
|
||||||
|
# try to say the abbreviation or spell it out as I've done below is not agreed upon
|
||||||
|
|
||||||
|
# For now, expand abbreviations to pronunciations
|
||||||
|
# replace_abbreviations adds a lot of unnecessary whitespace to ensure separation
|
||||||
|
string = replace_abbreviations(string)
|
||||||
|
string = replace_lowercase_abbreviations(string)
|
||||||
|
|
||||||
|
# cleanup whitespaces
|
||||||
|
# remove whitespace before punctuation
|
||||||
|
string = re.sub(rf'\s+({punctuation})', r'\1', string)
|
||||||
|
string = string.strip()
|
||||||
|
# compact whitespace
|
||||||
|
string = ' '.join(string.split())
|
||||||
|
|
||||||
|
return string
|
||||||
|
|
||||||
|
|
||||||
|
def remove_surrounded_chars(string):
|
||||||
|
# this expression matches to 'as few symbols as possible (0 upwards) between any asterisks' OR
|
||||||
|
# 'as few symbols as possible (0 upwards) between an asterisk and the end of the string'
|
||||||
|
return re.sub(r'\*[^*]*?(\*|$)', '', string)
|
||||||
|
|
||||||
|
|
||||||
|
def convert_num_locale(text):
|
||||||
|
# This detects locale and converts it to American without comma separators
|
||||||
|
pattern = re.compile(r'(?:\s|^)\d{1,3}(?:\.\d{3})+(,\d+)(?:\s|$)')
|
||||||
|
result = text
|
||||||
|
while True:
|
||||||
|
match = pattern.search(result)
|
||||||
|
if match is None:
|
||||||
|
break
|
||||||
|
|
||||||
|
start = match.start()
|
||||||
|
end = match.end()
|
||||||
|
result = result[0:start] + result[start:end].replace('.', '').replace(',', '.') + result[end:len(result)]
|
||||||
|
|
||||||
|
# removes comma separators from existing American numbers
|
||||||
|
pattern = re.compile(r'(\d),(\d)')
|
||||||
|
result = pattern.sub(r'\1\2', result)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def replace_negative(string):
|
||||||
|
# handles situations like -5. -5 would become negative 5, which would then be expanded to negative five
|
||||||
|
return re.sub(rf'(\s)(-)(\d+)({punctuation})', r'\1negative \3\4', string)
|
||||||
|
|
||||||
|
|
||||||
|
def replace_roman(string):
|
||||||
|
# find a string of roman numerals.
|
||||||
|
# Only 2 or more, to avoid capturing I and single character abbreviations, like names
|
||||||
|
pattern = re.compile(rf'\s[IVXLCDM]{{2,}}{punctuation}')
|
||||||
|
result = string
|
||||||
|
while True:
|
||||||
|
match = pattern.search(result)
|
||||||
|
if match is None:
|
||||||
|
break
|
||||||
|
|
||||||
|
start = match.start()
|
||||||
|
end = match.end()
|
||||||
|
result = result[0:start + 1] + str(roman_to_int(result[start + 1:end - 1])) + result[end - 1:len(result)]
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def roman_to_int(s):
|
||||||
|
rom_val = {'I': 1, 'V': 5, 'X': 10, 'L': 50, 'C': 100, 'D': 500, 'M': 1000}
|
||||||
|
int_val = 0
|
||||||
|
for i in range(len(s)):
|
||||||
|
if i > 0 and rom_val[s[i]] > rom_val[s[i - 1]]:
|
||||||
|
int_val += rom_val[s[i]] - 2 * rom_val[s[i - 1]]
|
||||||
|
else:
|
||||||
|
int_val += rom_val[s[i]]
|
||||||
|
return int_val
|
||||||
|
|
||||||
|
|
||||||
|
def hyphen_range_to(text):
|
||||||
|
pattern = re.compile(r'(\d+)[-–](\d+)')
|
||||||
|
result = pattern.sub(lambda x: x.group(1) + ' to ' + x.group(2), text)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def num_to_words(text):
|
||||||
|
# 1000 or 10.23
|
||||||
|
pattern = re.compile(r'\d+\.\d+|\d+')
|
||||||
|
result = pattern.sub(lambda x: num2words(float(x.group())), text)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def replace_abbreviations(string):
|
||||||
|
# abbreviations 1 to 4 characters long. It will get things like A and I, but those are pronounced with their letter
|
||||||
|
pattern = re.compile(rf'(^|[\s(.\'\[<])([A-Z]{{1,4}})({punctuation}|$)')
|
||||||
|
result = string
|
||||||
|
while True:
|
||||||
|
match = pattern.search(result)
|
||||||
|
if match is None:
|
||||||
|
break
|
||||||
|
|
||||||
|
start = match.start()
|
||||||
|
end = match.end()
|
||||||
|
result = result[0:start] + replace_abbreviation(result[start:end]) + result[end:len(result)]
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def replace_lowercase_abbreviations(string):
|
||||||
|
# abbreviations 1 to 4 characters long, separated by dots i.e. e.g.
|
||||||
|
pattern = re.compile(rf'(^|[\s(.\'\[<])(([a-z]\.){{1,4}})({punctuation}|$)')
|
||||||
|
result = string
|
||||||
|
while True:
|
||||||
|
match = pattern.search(result)
|
||||||
|
if match is None:
|
||||||
|
break
|
||||||
|
|
||||||
|
start = match.start()
|
||||||
|
end = match.end()
|
||||||
|
result = result[0:start] + replace_abbreviation(result[start:end].upper()) + result[end:len(result)]
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def replace_abbreviation(string):
|
||||||
|
result = ""
|
||||||
|
for char in string:
|
||||||
|
result += match_mapping(char)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def match_mapping(char):
|
||||||
|
for mapping in alphabet_map.keys():
|
||||||
|
if char == mapping:
|
||||||
|
return alphabet_map[char]
|
||||||
|
|
||||||
|
return char
|
||||||
|
|
||||||
|
|
||||||
|
def __main__(args):
|
||||||
|
print(preprocess(args[1]))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import sys
|
||||||
|
__main__(sys.argv)
|
||||||
@@ -1,5 +1,6 @@
|
|||||||
import gradio as gr
|
import gradio as gr
|
||||||
import speech_recognition as sr
|
import speech_recognition as sr
|
||||||
|
from modules import shared
|
||||||
|
|
||||||
input_hijack = {
|
input_hijack = {
|
||||||
'state': False,
|
'state': False,
|
||||||
@@ -7,7 +8,7 @@ input_hijack = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def do_stt(audio, text_state=""):
|
def do_stt(audio):
|
||||||
transcription = ""
|
transcription = ""
|
||||||
r = sr.Recognizer()
|
r = sr.Recognizer()
|
||||||
|
|
||||||
@@ -21,34 +22,23 @@ def do_stt(audio, text_state=""):
|
|||||||
except sr.RequestError as e:
|
except sr.RequestError as e:
|
||||||
print("Could not request results from Whisper", e)
|
print("Could not request results from Whisper", e)
|
||||||
|
|
||||||
input_hijack.update({"state": True, "value": [transcription, transcription]})
|
return transcription
|
||||||
|
|
||||||
text_state += transcription + " "
|
|
||||||
return text_state, text_state
|
|
||||||
|
|
||||||
|
|
||||||
def update_hijack(val):
|
def auto_transcribe(audio, auto_submit):
|
||||||
input_hijack.update({"state": True, "value": [val, val]})
|
|
||||||
return val
|
|
||||||
|
|
||||||
|
|
||||||
def auto_transcribe(audio, audio_auto, text_state=""):
|
|
||||||
if audio is None:
|
if audio is None:
|
||||||
return "", ""
|
return "", ""
|
||||||
if audio_auto:
|
|
||||||
return do_stt(audio, text_state)
|
transcription = do_stt(audio)
|
||||||
return "", ""
|
if auto_submit:
|
||||||
|
input_hijack.update({"state": True, "value": [transcription, transcription]})
|
||||||
|
|
||||||
|
return transcription, None
|
||||||
|
|
||||||
|
|
||||||
def ui():
|
def ui():
|
||||||
tr_state = gr.State(value="")
|
|
||||||
output_transcription = gr.Textbox(label="STT-Input",
|
|
||||||
placeholder="Speech Preview. Click \"Generate\" to send",
|
|
||||||
interactive=True)
|
|
||||||
output_transcription.change(fn=update_hijack, inputs=[output_transcription], outputs=[tr_state])
|
|
||||||
audio_auto = gr.Checkbox(label="Auto-Transcribe", value=True)
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
audio = gr.Audio(source="microphone")
|
audio = gr.Audio(source="microphone")
|
||||||
audio.change(fn=auto_transcribe, inputs=[audio, audio_auto, tr_state], outputs=[output_transcription, tr_state])
|
auto_submit = gr.Checkbox(label='Submit the transcribed audio automatically', value=True)
|
||||||
transcribe_button = gr.Button(value="Transcribe")
|
audio.change(fn=auto_transcribe, inputs=[audio, auto_submit], outputs=[shared.gradio['textbox'], audio])
|
||||||
transcribe_button.click(do_stt, inputs=[audio, tr_state], outputs=[output_transcription, tr_state])
|
audio.change(None, auto_submit, None, _js="(check) => {if (check) { document.getElementById('Generate').click() }}")
|
||||||
|
|||||||
@@ -100,10 +100,10 @@ def load_quantized(model_name):
|
|||||||
found_safetensors = list(path_to_model.glob("*.safetensors"))
|
found_safetensors = list(path_to_model.glob("*.safetensors"))
|
||||||
pt_path = None
|
pt_path = None
|
||||||
|
|
||||||
if len(found_pts) == 1:
|
if len(found_pts) > 0:
|
||||||
pt_path = found_pts[0]
|
pt_path = found_pts[-1]
|
||||||
elif len(found_safetensors) == 1:
|
elif len(found_safetensors) > 0:
|
||||||
pt_path = found_safetensors[0]
|
pt_path = found_safetensors[-1]
|
||||||
else:
|
else:
|
||||||
if path_to_model.name.lower().startswith('llama-7b'):
|
if path_to_model.name.lower().startswith('llama-7b'):
|
||||||
pt_model = f'llama-7b-{shared.args.wbits}bit'
|
pt_model = f'llama-7b-{shared.args.wbits}bit'
|
||||||
@@ -119,13 +119,14 @@ def load_quantized(model_name):
|
|||||||
# Try to find the .safetensors or .pt both in the model dir and in the subfolder
|
# Try to find the .safetensors or .pt both in the model dir and in the subfolder
|
||||||
for path in [Path(p + ext) for ext in ['.safetensors', '.pt'] for p in [f"{shared.args.model_dir}/{pt_model}", f"{path_to_model}/{pt_model}"]]:
|
for path in [Path(p + ext) for ext in ['.safetensors', '.pt'] for p in [f"{shared.args.model_dir}/{pt_model}", f"{path_to_model}/{pt_model}"]]:
|
||||||
if path.exists():
|
if path.exists():
|
||||||
print(f"Found {path}")
|
|
||||||
pt_path = path
|
pt_path = path
|
||||||
break
|
break
|
||||||
|
|
||||||
if not pt_path:
|
if not pt_path:
|
||||||
print("Could not find the quantized model in .pt or .safetensors format, exiting...")
|
print("Could not find the quantized model in .pt or .safetensors format, exiting...")
|
||||||
exit()
|
exit()
|
||||||
|
else:
|
||||||
|
print(f"Found the following quantized model: {pt_path}")
|
||||||
|
|
||||||
# qwopqwop200's offload
|
# qwopqwop200's offload
|
||||||
if model_type == 'llama' and shared.args.pre_layer:
|
if model_type == 'llama' and shared.args.pre_layer:
|
||||||
|
|||||||
@@ -4,14 +4,7 @@ import torch
|
|||||||
from peft import PeftModel
|
from peft import PeftModel
|
||||||
|
|
||||||
import modules.shared as shared
|
import modules.shared as shared
|
||||||
from modules.models import load_model
|
from modules.models import reload_model
|
||||||
from modules.text_generation import clear_torch_cache
|
|
||||||
|
|
||||||
|
|
||||||
def reload_model():
|
|
||||||
shared.model = shared.tokenizer = None
|
|
||||||
clear_torch_cache()
|
|
||||||
shared.model, shared.tokenizer = load_model(shared.model_name)
|
|
||||||
|
|
||||||
|
|
||||||
def add_lora_to_model(lora_name):
|
def add_lora_to_model(lora_name):
|
||||||
|
|||||||
208
modules/chat.py
208
modules/chat.py
@@ -12,53 +12,59 @@ from PIL import Image
|
|||||||
import modules.extensions as extensions_module
|
import modules.extensions as extensions_module
|
||||||
import modules.shared as shared
|
import modules.shared as shared
|
||||||
from modules.extensions import apply_extensions
|
from modules.extensions import apply_extensions
|
||||||
from modules.html_generator import (fix_newlines, chat_html_wrapper,
|
from modules.html_generator import (chat_html_wrapper, fix_newlines,
|
||||||
make_thumbnail)
|
make_thumbnail)
|
||||||
from modules.text_generation import (encode, generate_reply,
|
from modules.text_generation import (encode, generate_reply,
|
||||||
get_max_prompt_length)
|
get_max_prompt_length)
|
||||||
|
|
||||||
|
|
||||||
def generate_chat_prompt(user_input, max_new_tokens, name1, name2, context, chat_prompt_size, **kwargs):
|
def generate_chat_prompt(user_input, state, **kwargs):
|
||||||
is_instruct = kwargs['is_instruct'] if 'is_instruct' in kwargs else False
|
|
||||||
end_of_turn = kwargs['end_of_turn'] if 'end_of_turn' in kwargs else ''
|
|
||||||
impersonate = kwargs['impersonate'] if 'impersonate' in kwargs else False
|
impersonate = kwargs['impersonate'] if 'impersonate' in kwargs else False
|
||||||
|
_continue = kwargs['_continue'] if '_continue' in kwargs else False
|
||||||
also_return_rows = kwargs['also_return_rows'] if 'also_return_rows' in kwargs else False
|
also_return_rows = kwargs['also_return_rows'] if 'also_return_rows' in kwargs else False
|
||||||
rows = [f"{context.strip()}\n"]
|
is_instruct = state['mode'] == 'instruct'
|
||||||
|
rows = [f"{state['context'].strip()}\n"]
|
||||||
|
|
||||||
# Finding the maximum prompt size
|
# Finding the maximum prompt size
|
||||||
|
chat_prompt_size = state['chat_prompt_size']
|
||||||
if shared.soft_prompt:
|
if shared.soft_prompt:
|
||||||
chat_prompt_size -= shared.soft_prompt_tensor.shape[1]
|
chat_prompt_size -= shared.soft_prompt_tensor.shape[1]
|
||||||
max_length = min(get_max_prompt_length(max_new_tokens), chat_prompt_size)
|
max_length = min(get_max_prompt_length(state), chat_prompt_size)
|
||||||
|
|
||||||
if is_instruct:
|
if is_instruct:
|
||||||
prefix1 = f"{name1}\n"
|
prefix1 = f"{state['name1']}\n"
|
||||||
prefix2 = f"{name2}\n"
|
prefix2 = f"{state['name2']}\n"
|
||||||
else:
|
else:
|
||||||
prefix1 = f"{name1}: "
|
prefix1 = f"{state['name1']}: "
|
||||||
prefix2 = f"{name2}: "
|
prefix2 = f"{state['name2']}: "
|
||||||
|
|
||||||
i = len(shared.history['internal']) - 1
|
i = len(shared.history['internal']) - 1
|
||||||
while i >= 0 and len(encode(''.join(rows), max_new_tokens)[0]) < max_length:
|
while i >= 0 and len(encode(''.join(rows))[0]) < max_length:
|
||||||
rows.insert(1, f"{prefix2}{shared.history['internal'][i][1].strip()}{end_of_turn}\n")
|
if _continue and i == len(shared.history['internal']) - 1:
|
||||||
|
rows.insert(1, f"{prefix2}{shared.history['internal'][i][1]}")
|
||||||
|
else:
|
||||||
|
rows.insert(1, f"{prefix2}{shared.history['internal'][i][1].strip()}{state['end_of_turn']}\n")
|
||||||
string = shared.history['internal'][i][0]
|
string = shared.history['internal'][i][0]
|
||||||
if string not in ['', '<|BEGIN-VISIBLE-CHAT|>']:
|
if string not in ['', '<|BEGIN-VISIBLE-CHAT|>']:
|
||||||
rows.insert(1, f"{prefix1}{string.strip()}{end_of_turn}\n")
|
rows.insert(1, f"{prefix1}{string.strip()}{state['end_of_turn']}\n")
|
||||||
i -= 1
|
i -= 1
|
||||||
|
|
||||||
if impersonate:
|
if impersonate:
|
||||||
rows.append(f"{prefix1.strip() if not is_instruct else prefix1}")
|
rows.append(f"{prefix1.strip() if not is_instruct else prefix1}")
|
||||||
limit = 2
|
limit = 2
|
||||||
|
elif _continue:
|
||||||
|
limit = 3
|
||||||
else:
|
else:
|
||||||
# Adding the user message
|
# Adding the user message
|
||||||
user_input = fix_newlines(user_input)
|
user_input = fix_newlines(user_input)
|
||||||
if len(user_input) > 0:
|
if len(user_input) > 0:
|
||||||
rows.append(f"{prefix1}{user_input}{end_of_turn}\n")
|
rows.append(f"{prefix1}{user_input}{state['end_of_turn']}\n")
|
||||||
|
|
||||||
# Adding the Character prefix
|
# Adding the Character prefix
|
||||||
rows.append(apply_extensions(f"{prefix2.strip() if not is_instruct else prefix2}", "bot_prefix"))
|
rows.append(apply_extensions(f"{prefix2.strip() if not is_instruct else prefix2}", "bot_prefix"))
|
||||||
limit = 3
|
limit = 3
|
||||||
|
|
||||||
while len(rows) > limit and len(encode(''.join(rows), max_new_tokens)[0]) >= max_length:
|
while len(rows) > limit and len(encode(''.join(rows))[0]) >= max_length:
|
||||||
rows.pop(1)
|
rows.pop(1)
|
||||||
prompt = ''.join(rows)
|
prompt = ''.join(rows)
|
||||||
|
|
||||||
@@ -68,16 +74,26 @@ def generate_chat_prompt(user_input, max_new_tokens, name1, name2, context, chat
|
|||||||
return prompt
|
return prompt
|
||||||
|
|
||||||
|
|
||||||
def extract_message_from_reply(reply, name1, name2, stop_at_newline):
|
def get_stopping_strings(state):
|
||||||
next_character_found = False
|
if state['mode'] == 'instruct':
|
||||||
|
stopping_strings = [f"\n{state['name1']}", f"\n{state['name2']}"]
|
||||||
|
else:
|
||||||
|
stopping_strings = [f"\n{state['name1']}:", f"\n{state['name2']}:"]
|
||||||
|
stopping_strings += state['custom_stopping_strings']
|
||||||
|
return stopping_strings
|
||||||
|
|
||||||
if stop_at_newline:
|
|
||||||
|
def extract_message_from_reply(reply, state):
|
||||||
|
next_character_found = False
|
||||||
|
stopping_strings = get_stopping_strings(state)
|
||||||
|
|
||||||
|
if state['stop_at_newline']:
|
||||||
lines = reply.split('\n')
|
lines = reply.split('\n')
|
||||||
reply = lines[0].strip()
|
reply = lines[0].strip()
|
||||||
if len(lines) > 1:
|
if len(lines) > 1:
|
||||||
next_character_found = True
|
next_character_found = True
|
||||||
else:
|
else:
|
||||||
for string in [f"\n{name1}:", f"\n{name2}:"]:
|
for string in stopping_strings:
|
||||||
idx = reply.find(string)
|
idx = reply.find(string)
|
||||||
if idx != -1:
|
if idx != -1:
|
||||||
reply = reply[:idx]
|
reply = reply[:idx]
|
||||||
@@ -86,7 +102,7 @@ def extract_message_from_reply(reply, name1, name2, stop_at_newline):
|
|||||||
# If something like "\nYo" is generated just before "\nYou:"
|
# If something like "\nYo" is generated just before "\nYou:"
|
||||||
# is completed, trim it
|
# is completed, trim it
|
||||||
if not next_character_found:
|
if not next_character_found:
|
||||||
for string in [f"\n{name1}:", f"\n{name2}:"]:
|
for string in stopping_strings:
|
||||||
for j in range(len(string) - 1, 0, -1):
|
for j in range(len(string) - 1, 0, -1):
|
||||||
if reply[-j:] == string[:j]:
|
if reply[-j:] == string[:j]:
|
||||||
reply = reply[:-j]
|
reply = reply[:-j]
|
||||||
@@ -99,20 +115,17 @@ def extract_message_from_reply(reply, name1, name2, stop_at_newline):
|
|||||||
return reply, next_character_found
|
return reply, next_character_found
|
||||||
|
|
||||||
|
|
||||||
def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn, regenerate=False):
|
def chatbot_wrapper(text, state, regenerate=False, _continue=False):
|
||||||
if mode == 'instruct':
|
|
||||||
stopping_strings = [f"\n{name1}", f"\n{name2}"]
|
|
||||||
else:
|
|
||||||
stopping_strings = [f"\n{name1}:", f"\n{name2}:"]
|
|
||||||
|
|
||||||
eos_token = '\n' if generate_state['stop_at_newline'] else None
|
# Defining some variables
|
||||||
name1_original = name1
|
cumulative_reply = ''
|
||||||
if 'pygmalion' in shared.model_name.lower():
|
last_reply = [shared.history['internal'][-1][1], shared.history['visible'][-1][1]] if _continue else None
|
||||||
name1 = "You"
|
just_started = True
|
||||||
|
visible_text = custom_generate_chat_prompt = None
|
||||||
|
eos_token = '\n' if state['stop_at_newline'] else None
|
||||||
|
stopping_strings = get_stopping_strings(state)
|
||||||
|
|
||||||
# Check if any extension wants to hijack this function call
|
# Check if any extension wants to hijack this function call
|
||||||
visible_text = None
|
|
||||||
custom_generate_chat_prompt = None
|
|
||||||
for extension, _ in extensions_module.iterator():
|
for extension, _ in extensions_module.iterator():
|
||||||
if hasattr(extension, 'input_hijack') and extension.input_hijack['state']:
|
if hasattr(extension, 'input_hijack') and extension.input_hijack['state']:
|
||||||
extension.input_hijack['state'] = False
|
extension.input_hijack['state'] = False
|
||||||
@@ -122,29 +135,29 @@ def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_tu
|
|||||||
|
|
||||||
if visible_text is None:
|
if visible_text is None:
|
||||||
visible_text = text
|
visible_text = text
|
||||||
|
if not _continue:
|
||||||
text = apply_extensions(text, "input")
|
text = apply_extensions(text, "input")
|
||||||
|
|
||||||
kwargs = {'end_of_turn': end_of_turn, 'is_instruct': mode == 'instruct'}
|
# Generating the prompt
|
||||||
|
kwargs = {'_continue': _continue}
|
||||||
if custom_generate_chat_prompt is None:
|
if custom_generate_chat_prompt is None:
|
||||||
prompt = generate_chat_prompt(text, generate_state['max_new_tokens'], name1, name2, context, generate_state['chat_prompt_size'], **kwargs)
|
prompt = generate_chat_prompt(text, state, **kwargs)
|
||||||
else:
|
else:
|
||||||
prompt = custom_generate_chat_prompt(text, generate_state['max_new_tokens'], name1, name2, context, generate_state['chat_prompt_size'], **kwargs)
|
prompt = custom_generate_chat_prompt(text, state, **kwargs)
|
||||||
|
|
||||||
# Yield *Is typing...*
|
# Yield *Is typing...*
|
||||||
if not regenerate:
|
if not any((regenerate, _continue)):
|
||||||
yield shared.history['visible'] + [[visible_text, shared.processing_message]]
|
yield shared.history['visible'] + [[visible_text, shared.processing_message]]
|
||||||
|
|
||||||
# Generate
|
# Generate
|
||||||
cumulative_reply = ''
|
for i in range(state['chat_generation_attempts']):
|
||||||
just_started = True
|
|
||||||
for i in range(generate_state['chat_generation_attempts']):
|
|
||||||
reply = None
|
reply = None
|
||||||
for reply in generate_reply(f"{prompt}{' ' if len(cumulative_reply) > 0 else ''}{cumulative_reply}", generate_state, eos_token=eos_token, stopping_strings=stopping_strings):
|
for reply in generate_reply(f"{prompt}{' ' if len(cumulative_reply) > 0 else ''}{cumulative_reply}", state, eos_token=eos_token, stopping_strings=stopping_strings):
|
||||||
reply = cumulative_reply + reply
|
reply = cumulative_reply + reply
|
||||||
|
|
||||||
# Extracting the reply
|
# Extracting the reply
|
||||||
reply, next_character_found = extract_message_from_reply(reply, name1, name2, generate_state['stop_at_newline'])
|
reply, next_character_found = extract_message_from_reply(reply, state)
|
||||||
visible_reply = re.sub("(<USER>|<user>|{{user}})", name1_original, reply)
|
visible_reply = re.sub("(<USER>|<user>|{{user}})", state['name1'], reply)
|
||||||
visible_reply = apply_extensions(visible_reply, "output")
|
visible_reply = apply_extensions(visible_reply, "output")
|
||||||
|
|
||||||
# We need this global variable to handle the Stop event,
|
# We need this global variable to handle the Stop event,
|
||||||
@@ -153,9 +166,15 @@ def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_tu
|
|||||||
return shared.history['visible']
|
return shared.history['visible']
|
||||||
if just_started:
|
if just_started:
|
||||||
just_started = False
|
just_started = False
|
||||||
|
if not _continue:
|
||||||
shared.history['internal'].append(['', ''])
|
shared.history['internal'].append(['', ''])
|
||||||
shared.history['visible'].append(['', ''])
|
shared.history['visible'].append(['', ''])
|
||||||
|
|
||||||
|
if _continue:
|
||||||
|
sep = list(map(lambda x: ' ' if len(x) > 0 and x[-1] != ' ' else '', last_reply))
|
||||||
|
shared.history['internal'][-1] = [text, f'{last_reply[0]}{sep[0]}{reply}']
|
||||||
|
shared.history['visible'][-1] = [visible_text, f'{last_reply[1]}{sep[1]}{visible_reply}']
|
||||||
|
else:
|
||||||
shared.history['internal'][-1] = [text, reply]
|
shared.history['internal'][-1] = [text, reply]
|
||||||
shared.history['visible'][-1] = [visible_text, visible_reply]
|
shared.history['visible'][-1] = [visible_text, visible_reply]
|
||||||
if not shared.args.no_stream:
|
if not shared.args.no_stream:
|
||||||
@@ -169,27 +188,22 @@ def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_tu
|
|||||||
yield shared.history['visible']
|
yield shared.history['visible']
|
||||||
|
|
||||||
|
|
||||||
def impersonate_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn):
|
def impersonate_wrapper(text, state):
|
||||||
if mode == 'instruct':
|
|
||||||
stopping_strings = [f"\n{name1}", f"\n{name2}"]
|
|
||||||
else:
|
|
||||||
stopping_strings = [f"\n{name1}:", f"\n{name2}:"]
|
|
||||||
|
|
||||||
eos_token = '\n' if generate_state['stop_at_newline'] else None
|
# Defining some variables
|
||||||
if 'pygmalion' in shared.model_name.lower():
|
cumulative_reply = ''
|
||||||
name1 = "You"
|
eos_token = '\n' if state['stop_at_newline'] else None
|
||||||
|
prompt = generate_chat_prompt(text, state, impersonate=True)
|
||||||
prompt = generate_chat_prompt(text, generate_state['max_new_tokens'], name1, name2, context, generate_state['chat_prompt_size'], impersonate=True, end_of_turn=end_of_turn)
|
stopping_strings = get_stopping_strings(state)
|
||||||
|
|
||||||
# Yield *Is typing...*
|
# Yield *Is typing...*
|
||||||
yield shared.processing_message
|
yield shared.processing_message
|
||||||
|
|
||||||
cumulative_reply = ''
|
for i in range(state['chat_generation_attempts']):
|
||||||
for i in range(generate_state['chat_generation_attempts']):
|
|
||||||
reply = None
|
reply = None
|
||||||
for reply in generate_reply(f"{prompt}{' ' if len(cumulative_reply) > 0 else ''}{cumulative_reply}", generate_state, eos_token=eos_token, stopping_strings=stopping_strings):
|
for reply in generate_reply(f"{prompt}{' ' if len(cumulative_reply) > 0 else ''}{cumulative_reply}", state, eos_token=eos_token, stopping_strings=stopping_strings):
|
||||||
reply = cumulative_reply + reply
|
reply = cumulative_reply + reply
|
||||||
reply, next_character_found = extract_message_from_reply(reply, name1, name2, generate_state['stop_at_newline'])
|
reply, next_character_found = extract_message_from_reply(reply, state)
|
||||||
yield reply
|
yield reply
|
||||||
if next_character_found:
|
if next_character_found:
|
||||||
break
|
break
|
||||||
@@ -200,22 +214,32 @@ def impersonate_wrapper(text, generate_state, name1, name2, context, mode, end_o
|
|||||||
yield reply
|
yield reply
|
||||||
|
|
||||||
|
|
||||||
def cai_chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn):
|
def cai_chatbot_wrapper(text, state):
|
||||||
for history in chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn):
|
for history in chatbot_wrapper(text, state):
|
||||||
yield chat_html_wrapper(history, name1, name2, mode)
|
yield chat_html_wrapper(history, state['name1'], state['name2'], state['mode'])
|
||||||
|
|
||||||
|
|
||||||
def regenerate_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn):
|
def regenerate_wrapper(text, state):
|
||||||
if (shared.character != 'None' and len(shared.history['visible']) == 1) or len(shared.history['internal']) == 0:
|
if (len(shared.history['visible']) == 1 and not shared.history['visible'][0][0]) or len(shared.history['internal']) == 0:
|
||||||
yield chat_html_wrapper(shared.history['visible'], name1, name2, mode)
|
yield chat_html_wrapper(shared.history['visible'], state['name1'], state['name2'], state['mode'])
|
||||||
else:
|
else:
|
||||||
last_visible = shared.history['visible'].pop()
|
last_visible = shared.history['visible'].pop()
|
||||||
last_internal = shared.history['internal'].pop()
|
last_internal = shared.history['internal'].pop()
|
||||||
# Yield '*Is typing...*'
|
# Yield '*Is typing...*'
|
||||||
yield chat_html_wrapper(shared.history['visible'] + [[last_visible[0], shared.processing_message]], name1, name2, mode)
|
yield chat_html_wrapper(shared.history['visible'] + [[last_visible[0], shared.processing_message]], state['name1'], state['name2'], state['mode'])
|
||||||
for history in chatbot_wrapper(last_internal[0], generate_state, name1, name2, context, mode, end_of_turn, regenerate=True):
|
for history in chatbot_wrapper(last_internal[0], state, regenerate=True):
|
||||||
shared.history['visible'][-1] = [last_visible[0], history[-1][1]]
|
shared.history['visible'][-1] = [last_visible[0], history[-1][1]]
|
||||||
yield chat_html_wrapper(shared.history['visible'], name1, name2, mode)
|
yield chat_html_wrapper(shared.history['visible'], state['name1'], state['name2'], state['mode'])
|
||||||
|
|
||||||
|
|
||||||
|
def continue_wrapper(text, state):
|
||||||
|
if (len(shared.history['visible']) == 1 and not shared.history['visible'][0][0]) or len(shared.history['internal']) == 0:
|
||||||
|
yield chat_html_wrapper(shared.history['visible'], state['name1'], state['name2'], state['mode'])
|
||||||
|
else:
|
||||||
|
# Yield ' ...'
|
||||||
|
yield chat_html_wrapper(shared.history['visible'][:-1] + [[shared.history['visible'][-1][0], shared.history['visible'][-1][1] + ' ...']], state['name1'], state['name2'], state['mode'])
|
||||||
|
for history in chatbot_wrapper(shared.history['internal'][-1][0], state, _continue=True):
|
||||||
|
yield chat_html_wrapper(shared.history['visible'], state['name1'], state['name2'], state['mode'])
|
||||||
|
|
||||||
|
|
||||||
def remove_last_message(name1, name2, mode):
|
def remove_last_message(name1, name2, mode):
|
||||||
@@ -243,6 +267,21 @@ def replace_last_reply(text, name1, name2, mode):
|
|||||||
return chat_html_wrapper(shared.history['visible'], name1, name2, mode)
|
return chat_html_wrapper(shared.history['visible'], name1, name2, mode)
|
||||||
|
|
||||||
|
|
||||||
|
def send_dummy_message(text, name1, name2, mode):
|
||||||
|
shared.history['visible'].append([text, ''])
|
||||||
|
shared.history['internal'].append([apply_extensions(text, "input"), ''])
|
||||||
|
return chat_html_wrapper(shared.history['visible'], name1, name2, mode)
|
||||||
|
|
||||||
|
|
||||||
|
def send_dummy_reply(text, name1, name2, mode):
|
||||||
|
if len(shared.history['visible']) > 0 and not shared.history['visible'][-1][1] == '':
|
||||||
|
shared.history['visible'].append(['', ''])
|
||||||
|
shared.history['internal'].append(['', ''])
|
||||||
|
shared.history['visible'][-1][1] = text
|
||||||
|
shared.history['internal'][-1][1] = apply_extensions(text, "input")
|
||||||
|
return chat_html_wrapper(shared.history['visible'], name1, name2, mode)
|
||||||
|
|
||||||
|
|
||||||
def clear_html():
|
def clear_html():
|
||||||
return chat_html_wrapper([], "", "")
|
return chat_html_wrapper([], "", "")
|
||||||
|
|
||||||
@@ -255,6 +294,9 @@ def clear_chat_log(name1, name2, greeting, mode):
|
|||||||
shared.history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', greeting]]
|
shared.history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', greeting]]
|
||||||
shared.history['visible'] += [['', apply_extensions(greeting, "output")]]
|
shared.history['visible'] += [['', apply_extensions(greeting, "output")]]
|
||||||
|
|
||||||
|
# Save cleared logs
|
||||||
|
save_history(mode)
|
||||||
|
|
||||||
return chat_html_wrapper(shared.history['visible'], name1, name2, mode)
|
return chat_html_wrapper(shared.history['visible'], name1, name2, mode)
|
||||||
|
|
||||||
|
|
||||||
@@ -264,7 +306,7 @@ def redraw_html(name1, name2, mode):
|
|||||||
|
|
||||||
def tokenize_dialogue(dialogue, name1, name2, mode):
|
def tokenize_dialogue(dialogue, name1, name2, mode):
|
||||||
history = []
|
history = []
|
||||||
|
messages = []
|
||||||
dialogue = re.sub('<START>', '', dialogue)
|
dialogue = re.sub('<START>', '', dialogue)
|
||||||
dialogue = re.sub('<start>', '', dialogue)
|
dialogue = re.sub('<start>', '', dialogue)
|
||||||
dialogue = re.sub('(\n|^)[Aa]non:', '\\1You:', dialogue)
|
dialogue = re.sub('(\n|^)[Aa]non:', '\\1You:', dialogue)
|
||||||
@@ -273,7 +315,6 @@ def tokenize_dialogue(dialogue, name1, name2, mode):
|
|||||||
if len(idx) == 0:
|
if len(idx) == 0:
|
||||||
return history
|
return history
|
||||||
|
|
||||||
messages = []
|
|
||||||
for i in range(len(idx) - 1):
|
for i in range(len(idx) - 1):
|
||||||
messages.append(dialogue[idx[i]:idx[i + 1]].strip())
|
messages.append(dialogue[idx[i]:idx[i + 1]].strip())
|
||||||
messages.append(dialogue[idx[-1]:].strip())
|
messages.append(dialogue[idx[-1]:].strip())
|
||||||
@@ -300,7 +341,14 @@ def tokenize_dialogue(dialogue, name1, name2, mode):
|
|||||||
return history
|
return history
|
||||||
|
|
||||||
|
|
||||||
def save_history(timestamp=True):
|
def save_history(mode, timestamp=False):
|
||||||
|
# Instruct mode histories should not be saved as if
|
||||||
|
# Alpaca or Vicuna were characters
|
||||||
|
if mode == 'instruct':
|
||||||
|
if not timestamp:
|
||||||
|
return
|
||||||
|
fname = f"Instruct_{datetime.now().strftime('%Y%m%d-%H%M%S')}.json"
|
||||||
|
else:
|
||||||
if timestamp:
|
if timestamp:
|
||||||
fname = f"{shared.character}_{datetime.now().strftime('%Y%m%d-%H%M%S')}.json"
|
fname = f"{shared.character}_{datetime.now().strftime('%Y%m%d-%H%M%S')}.json"
|
||||||
else:
|
else:
|
||||||
@@ -309,6 +357,7 @@ def save_history(timestamp=True):
|
|||||||
Path('logs').mkdir()
|
Path('logs').mkdir()
|
||||||
with open(Path(f'logs/{fname}'), 'w', encoding='utf-8') as f:
|
with open(Path(f'logs/{fname}'), 'w', encoding='utf-8') as f:
|
||||||
f.write(json.dumps({'data': shared.history['internal'], 'data_visible': shared.history['visible']}, indent=2))
|
f.write(json.dumps({'data': shared.history['internal'], 'data_visible': shared.history['visible']}, indent=2))
|
||||||
|
|
||||||
return Path(f'logs/{fname}')
|
return Path(f'logs/{fname}')
|
||||||
|
|
||||||
|
|
||||||
@@ -322,16 +371,6 @@ def load_history(file, name1, name2):
|
|||||||
shared.history['visible'] = j['data_visible']
|
shared.history['visible'] = j['data_visible']
|
||||||
else:
|
else:
|
||||||
shared.history['visible'] = copy.deepcopy(shared.history['internal'])
|
shared.history['visible'] = copy.deepcopy(shared.history['internal'])
|
||||||
# Compatibility with Pygmalion AI's official web UI
|
|
||||||
elif 'chat' in j:
|
|
||||||
shared.history['internal'] = [':'.join(x.split(':')[1:]).strip() for x in j['chat']]
|
|
||||||
if len(j['chat']) > 0 and j['chat'][0].startswith(f'{name2}:'):
|
|
||||||
shared.history['internal'] = [['<|BEGIN-VISIBLE-CHAT|>', shared.history['internal'][0]]] + [[shared.history['internal'][i], shared.history['internal'][i + 1]] for i in range(1, len(shared.history['internal']) - 1, 2)]
|
|
||||||
shared.history['visible'] = copy.deepcopy(shared.history['internal'])
|
|
||||||
shared.history['visible'][0][0] = ''
|
|
||||||
else:
|
|
||||||
shared.history['internal'] = [[shared.history['internal'][i], shared.history['internal'][i + 1]] for i in range(0, len(shared.history['internal']) - 1, 2)]
|
|
||||||
shared.history['visible'] = copy.deepcopy(shared.history['internal'])
|
|
||||||
except:
|
except:
|
||||||
shared.history['internal'] = tokenize_dialogue(file, name1, name2)
|
shared.history['internal'] = tokenize_dialogue(file, name1, name2)
|
||||||
shared.history['visible'] = copy.deepcopy(shared.history['internal'])
|
shared.history['visible'] = copy.deepcopy(shared.history['internal'])
|
||||||
@@ -367,8 +406,6 @@ def generate_pfp_cache(character):
|
|||||||
|
|
||||||
def load_character(character, name1, name2, mode):
|
def load_character(character, name1, name2, mode):
|
||||||
shared.character = character
|
shared.character = character
|
||||||
shared.history['internal'] = []
|
|
||||||
shared.history['visible'] = []
|
|
||||||
context = greeting = end_of_turn = ""
|
context = greeting = end_of_turn = ""
|
||||||
greeting_field = 'greeting'
|
greeting_field = 'greeting'
|
||||||
picture = None
|
picture = None
|
||||||
@@ -413,13 +450,22 @@ def load_character(character, name1, name2, mode):
|
|||||||
greeting = shared.settings['greeting']
|
greeting = shared.settings['greeting']
|
||||||
end_of_turn = shared.settings['end_of_turn']
|
end_of_turn = shared.settings['end_of_turn']
|
||||||
|
|
||||||
|
if mode != 'instruct':
|
||||||
|
shared.history['internal'] = []
|
||||||
|
shared.history['visible'] = []
|
||||||
|
|
||||||
if Path(f'logs/{shared.character}_persistent.json').exists():
|
if Path(f'logs/{shared.character}_persistent.json').exists():
|
||||||
load_history(open(Path(f'logs/{shared.character}_persistent.json'), 'rb').read(), name1, name2)
|
load_history(open(Path(f'logs/{shared.character}_persistent.json'), 'rb').read(), name1, name2)
|
||||||
elif greeting != "":
|
else:
|
||||||
|
# Insert greeting if it exists
|
||||||
|
if greeting != "":
|
||||||
shared.history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', greeting]]
|
shared.history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', greeting]]
|
||||||
shared.history['visible'] += [['', apply_extensions(greeting, "output")]]
|
shared.history['visible'] += [['', apply_extensions(greeting, "output")]]
|
||||||
|
|
||||||
return name1, name2, picture, greeting, context, end_of_turn, chat_html_wrapper(shared.history['visible'], name1, name2, mode, reset_cache=True)
|
# Create .json log files since they don't already exist
|
||||||
|
save_history(mode)
|
||||||
|
|
||||||
|
return name1, name2, picture, greeting, context, end_of_turn, chat_html_wrapper(shared.history['visible'], name1, name2, mode)
|
||||||
|
|
||||||
|
|
||||||
def load_default_history(name1, name2):
|
def load_default_history(name1, name2):
|
||||||
|
|||||||
@@ -11,29 +11,31 @@ setup_called = set()
|
|||||||
|
|
||||||
|
|
||||||
def load_extensions():
|
def load_extensions():
|
||||||
global state
|
global state, setup_called
|
||||||
for i, name in enumerate(shared.args.extensions):
|
for i, name in enumerate(shared.args.extensions):
|
||||||
if name in available_extensions:
|
if name in available_extensions:
|
||||||
print(f'Loading the extension "{name}"... ', end='')
|
print(f'Loading the extension "{name}"... ', end='')
|
||||||
try:
|
try:
|
||||||
exec(f"import extensions.{name}.script")
|
exec(f"import extensions.{name}.script")
|
||||||
|
extension = eval(f"extensions.{name}.script")
|
||||||
|
if extension not in setup_called and hasattr(extension, "setup"):
|
||||||
|
setup_called.add(extension)
|
||||||
|
extension.setup()
|
||||||
state[name] = [True, i]
|
state[name] = [True, i]
|
||||||
print('Ok.')
|
print('Ok.')
|
||||||
except:
|
except:
|
||||||
print('Fail.')
|
print('Fail.')
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
|
|
||||||
|
|
||||||
# This iterator returns the extensions in the order specified in the command-line
|
# This iterator returns the extensions in the order specified in the command-line
|
||||||
|
|
||||||
|
|
||||||
def iterator():
|
def iterator():
|
||||||
for name in sorted(state, key=lambda x: state[x][1]):
|
for name in sorted(state, key=lambda x: state[x][1]):
|
||||||
if state[name][0] == True:
|
if state[name][0]:
|
||||||
yield eval(f"extensions.{name}.script"), name
|
yield eval(f"extensions.{name}.script"), name
|
||||||
|
|
||||||
|
|
||||||
# Extension functions that map string -> string
|
# Extension functions that map string -> string
|
||||||
|
|
||||||
|
|
||||||
def apply_extensions(text, typ):
|
def apply_extensions(text, typ):
|
||||||
for extension, _ in iterator():
|
for extension, _ in iterator():
|
||||||
if typ == "input" and hasattr(extension, "input_modifier"):
|
if typ == "input" and hasattr(extension, "input_modifier"):
|
||||||
@@ -57,14 +59,9 @@ def create_extensions_block():
|
|||||||
extension.params[param] = shared.settings[_id]
|
extension.params[param] = shared.settings[_id]
|
||||||
|
|
||||||
should_display_ui = False
|
should_display_ui = False
|
||||||
|
|
||||||
# Running setup function
|
|
||||||
for extension, name in iterator():
|
for extension, name in iterator():
|
||||||
if hasattr(extension, "ui"):
|
if hasattr(extension, "ui"):
|
||||||
should_display_ui = True
|
should_display_ui = True
|
||||||
if extension not in setup_called and hasattr(extension, "setup"):
|
|
||||||
setup_called.add(extension)
|
|
||||||
extension.setup()
|
|
||||||
|
|
||||||
# Creating the extension ui elements
|
# Creating the extension ui elements
|
||||||
if should_display_ui:
|
if should_display_ui:
|
||||||
|
|||||||
@@ -164,10 +164,9 @@ def generate_instruct_html(history):
|
|||||||
def generate_cai_chat_html(history, name1, name2, reset_cache=False):
|
def generate_cai_chat_html(history, name1, name2, reset_cache=False):
|
||||||
output = f'<style>{cai_css}</style><div class="chat" id="chat">'
|
output = f'<style>{cai_css}</style><div class="chat" id="chat">'
|
||||||
|
|
||||||
# The time.time() is to prevent the brower from caching the image
|
# We use ?name2 and ?time.time() to force the browser to reset caches
|
||||||
suffix = f"?{time.time()}" if reset_cache else f"?{name2}"
|
img_bot = f'<img src="file/cache/pfp_character.png?{name2}">' if Path("cache/pfp_character.png").exists() else ''
|
||||||
img_bot = f'<img src="file/cache/pfp_character.png{suffix}">' if Path("cache/pfp_character.png").exists() else ''
|
img_me = f'<img src="file/cache/pfp_me.png?{time.time() if reset_cache else ""}">' if Path("cache/pfp_me.png").exists() else ''
|
||||||
img_me = f'<img src="file/cache/pfp_me.png{suffix}">' if Path("cache/pfp_me.png").exists() else ''
|
|
||||||
|
|
||||||
for i, _row in enumerate(history[::-1]):
|
for i, _row in enumerate(history[::-1]):
|
||||||
row = [convert_to_markdown(entry) for entry in _row]
|
row = [convert_to_markdown(entry) for entry in _row]
|
||||||
|
|||||||
176
modules/llama_attn_hijack.py
Normal file
176
modules/llama_attn_hijack.py
Normal file
@@ -0,0 +1,176 @@
|
|||||||
|
import math
|
||||||
|
import sys
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import transformers.models.llama.modeling_llama
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
|
import modules.shared as shared
|
||||||
|
|
||||||
|
|
||||||
|
if shared.args.xformers:
|
||||||
|
try:
|
||||||
|
import xformers.ops
|
||||||
|
except Exception:
|
||||||
|
print("🔴 xformers not found! Please install it before trying to use it.", file=sys.stderr)
|
||||||
|
|
||||||
|
|
||||||
|
def hijack_llama_attention():
|
||||||
|
if shared.args.xformers:
|
||||||
|
transformers.models.llama.modeling_llama.LlamaAttention.forward = xformers_forward
|
||||||
|
print("Replaced attention with xformers_attention")
|
||||||
|
elif shared.args.sdp_attention:
|
||||||
|
transformers.models.llama.modeling_llama.LlamaAttention.forward = sdp_attention_forward
|
||||||
|
print("Replaced attention with sdp_attention")
|
||||||
|
|
||||||
|
|
||||||
|
def xformers_forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||||
|
output_attentions: bool = False,
|
||||||
|
use_cache: bool = False,
|
||||||
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||||
|
bsz, q_len, _ = hidden_states.size()
|
||||||
|
|
||||||
|
query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||||
|
key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||||
|
value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||||
|
|
||||||
|
kv_seq_len = key_states.shape[-2]
|
||||||
|
if past_key_value is not None:
|
||||||
|
kv_seq_len += past_key_value[0].shape[-2]
|
||||||
|
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||||
|
query_states, key_states = transformers.models.llama.modeling_llama.apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
||||||
|
# [bsz, nh, t, hd]
|
||||||
|
|
||||||
|
if past_key_value is not None:
|
||||||
|
# reuse k, v, self_attention
|
||||||
|
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
||||||
|
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
||||||
|
|
||||||
|
past_key_value = (key_states, value_states) if use_cache else None
|
||||||
|
|
||||||
|
#We only apply xformers optimizations if we don't need to output the whole attention matrix
|
||||||
|
if not output_attentions:
|
||||||
|
dtype = query_states.dtype
|
||||||
|
|
||||||
|
query_states = query_states.transpose(1, 2)
|
||||||
|
key_states = key_states.transpose(1, 2)
|
||||||
|
value_states = value_states.transpose(1, 2)
|
||||||
|
|
||||||
|
#This is a nasty hack. We know attention_mask in transformers is either LowerTriangular or all Zeros.
|
||||||
|
#We therefore check if one element in the upper triangular portion is zero. If it is, then the mask is all zeros.
|
||||||
|
if attention_mask is None or attention_mask[0, 0, 0, 1] == 0:
|
||||||
|
# input and output should be of form (bsz, q_len, num_heads, head_dim)
|
||||||
|
attn_output = xformers.ops.memory_efficient_attention(query_states, key_states, value_states, attn_bias=None)
|
||||||
|
else:
|
||||||
|
# input and output should be of form (bsz, q_len, num_heads, head_dim)
|
||||||
|
attn_output = xformers.ops.memory_efficient_attention(query_states, key_states, value_states, attn_bias=xformers.ops.LowerTriangularMask())
|
||||||
|
attn_weights = None
|
||||||
|
else:
|
||||||
|
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
|
||||||
|
|
||||||
|
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
|
||||||
|
raise ValueError(
|
||||||
|
f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is"
|
||||||
|
f" {attn_weights.size()}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if attention_mask is not None:
|
||||||
|
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
|
||||||
|
raise ValueError(
|
||||||
|
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
|
||||||
|
)
|
||||||
|
attn_weights = attn_weights + attention_mask
|
||||||
|
attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
|
||||||
|
|
||||||
|
# upcast attention to fp32
|
||||||
|
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
|
||||||
|
attn_output = torch.matmul(attn_weights, value_states)
|
||||||
|
|
||||||
|
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
|
||||||
|
raise ValueError(
|
||||||
|
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
|
||||||
|
f" {attn_output.size()}"
|
||||||
|
)
|
||||||
|
|
||||||
|
attn_output = attn_output.transpose(1, 2)
|
||||||
|
|
||||||
|
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
||||||
|
|
||||||
|
attn_output = self.o_proj(attn_output)
|
||||||
|
|
||||||
|
return attn_output, attn_weights, past_key_value
|
||||||
|
|
||||||
|
|
||||||
|
def sdp_attention_forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||||
|
output_attentions: bool = False,
|
||||||
|
use_cache: bool = False,
|
||||||
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||||
|
bsz, q_len, _ = hidden_states.size()
|
||||||
|
|
||||||
|
query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||||
|
key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||||
|
value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||||
|
|
||||||
|
kv_seq_len = key_states.shape[-2]
|
||||||
|
if past_key_value is not None:
|
||||||
|
kv_seq_len += past_key_value[0].shape[-2]
|
||||||
|
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||||
|
query_states, key_states = transformers.models.llama.modeling_llama.apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
||||||
|
# [bsz, nh, t, hd]
|
||||||
|
|
||||||
|
if past_key_value is not None:
|
||||||
|
# reuse k, v, self_attention
|
||||||
|
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
||||||
|
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
||||||
|
|
||||||
|
past_key_value = (key_states, value_states) if use_cache else None
|
||||||
|
|
||||||
|
#We only apply sdp attention if we don't need to output the whole attention matrix
|
||||||
|
if not output_attentions:
|
||||||
|
attn_output = torch.nn.functional.scaled_dot_product_attention(query_states, key_states, value_states, attn_mask=attention_mask, is_causal=False)
|
||||||
|
attn_weights = None
|
||||||
|
else:
|
||||||
|
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
|
||||||
|
|
||||||
|
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
|
||||||
|
raise ValueError(
|
||||||
|
f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is"
|
||||||
|
f" {attn_weights.size()}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if attention_mask is not None:
|
||||||
|
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
|
||||||
|
raise ValueError(
|
||||||
|
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
|
||||||
|
)
|
||||||
|
attn_weights = attn_weights + attention_mask
|
||||||
|
attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
|
||||||
|
|
||||||
|
# upcast attention to fp32
|
||||||
|
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
|
||||||
|
attn_output = torch.matmul(attn_weights, value_states)
|
||||||
|
|
||||||
|
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
|
||||||
|
raise ValueError(
|
||||||
|
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
|
||||||
|
f" {attn_output.size()}"
|
||||||
|
)
|
||||||
|
|
||||||
|
attn_output = attn_output.transpose(1, 2)
|
||||||
|
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
||||||
|
|
||||||
|
attn_output = self.o_proj(attn_output)
|
||||||
|
|
||||||
|
return attn_output, attn_weights, past_key_value
|
||||||
@@ -1,3 +1,4 @@
|
|||||||
|
import gc
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
@@ -13,14 +14,14 @@ from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer,
|
|||||||
BitsAndBytesConfig, LlamaTokenizer)
|
BitsAndBytesConfig, LlamaTokenizer)
|
||||||
|
|
||||||
import modules.shared as shared
|
import modules.shared as shared
|
||||||
|
from modules import llama_attn_hijack
|
||||||
|
|
||||||
transformers.logging.set_verbosity_error()
|
transformers.logging.set_verbosity_error()
|
||||||
|
|
||||||
local_rank = None
|
|
||||||
|
|
||||||
if shared.args.flexgen:
|
if shared.args.flexgen:
|
||||||
from flexgen.flex_opt import CompressionConfig, ExecutionEnv, OptLM, Policy
|
from flexgen.flex_opt import CompressionConfig, ExecutionEnv, OptLM, Policy
|
||||||
|
|
||||||
|
local_rank = None
|
||||||
if shared.args.deepspeed:
|
if shared.args.deepspeed:
|
||||||
import deepspeed
|
import deepspeed
|
||||||
from transformers.deepspeed import (HfDeepSpeedConfig,
|
from transformers.deepspeed import (HfDeepSpeedConfig,
|
||||||
@@ -169,19 +170,46 @@ def load_model(model_name):
|
|||||||
|
|
||||||
model = AutoModelForCausalLM.from_pretrained(checkpoint, **params)
|
model = AutoModelForCausalLM.from_pretrained(checkpoint, **params)
|
||||||
|
|
||||||
|
# Hijack attention with xformers
|
||||||
|
if any((shared.args.xformers, shared.args.sdp_attention)):
|
||||||
|
llama_attn_hijack.hijack_llama_attention()
|
||||||
|
|
||||||
# Loading the tokenizer
|
# Loading the tokenizer
|
||||||
if any((k in shared.model_name.lower() for k in ['gpt4chan', 'gpt-4chan'])) and Path(f"{shared.args.model_dir}/gpt-j-6B/").exists():
|
if any((k in shared.model_name.lower() for k in ['gpt4chan', 'gpt-4chan'])) and Path(f"{shared.args.model_dir}/gpt-j-6B/").exists():
|
||||||
tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/gpt-j-6B/"))
|
tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/gpt-j-6B/"))
|
||||||
elif type(model) is transformers.LlamaForCausalLM:
|
elif type(model) is transformers.LlamaForCausalLM:
|
||||||
tokenizer = LlamaTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/{shared.model_name}/"), clean_up_tokenization_spaces=True)
|
tokenizer = LlamaTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/{shared.model_name}/"), clean_up_tokenization_spaces=True)
|
||||||
|
# Leaving this here until the LLaMA tokenizer gets figured out.
|
||||||
|
# For some people this fixes things, for others it causes an error.
|
||||||
|
try:
|
||||||
|
tokenizer.eos_token_id = 2
|
||||||
|
tokenizer.bos_token_id = 1
|
||||||
|
tokenizer.pad_token_id = 0
|
||||||
|
except:
|
||||||
|
pass
|
||||||
else:
|
else:
|
||||||
tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/{shared.model_name}/"))
|
tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/{shared.model_name}/"))
|
||||||
tokenizer.truncation_side = 'left'
|
|
||||||
|
|
||||||
print(f"Loaded the model in {(time.time()-t0):.2f} seconds.")
|
print(f"Loaded the model in {(time.time()-t0):.2f} seconds.")
|
||||||
return model, tokenizer
|
return model, tokenizer
|
||||||
|
|
||||||
|
|
||||||
|
def clear_torch_cache():
|
||||||
|
gc.collect()
|
||||||
|
if not shared.args.cpu:
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
|
||||||
|
def unload_model():
|
||||||
|
shared.model = shared.tokenizer = None
|
||||||
|
clear_torch_cache()
|
||||||
|
|
||||||
|
|
||||||
|
def reload_model():
|
||||||
|
unload_model()
|
||||||
|
shared.model, shared.tokenizer = load_model(shared.model_name)
|
||||||
|
|
||||||
|
|
||||||
def load_soft_prompt(name):
|
def load_soft_prompt(name):
|
||||||
if name == 'None':
|
if name == 'None':
|
||||||
shared.soft_prompt = False
|
shared.soft_prompt = False
|
||||||
|
|||||||
@@ -34,7 +34,13 @@ settings = {
|
|||||||
'context': 'This is a conversation with your Assistant. The Assistant is very helpful and is eager to chat with you and answer your questions.',
|
'context': 'This is a conversation with your Assistant. The Assistant is very helpful and is eager to chat with you and answer your questions.',
|
||||||
'greeting': 'Hello there!',
|
'greeting': 'Hello there!',
|
||||||
'end_of_turn': '',
|
'end_of_turn': '',
|
||||||
|
'custom_stopping_strings': '',
|
||||||
'stop_at_newline': False,
|
'stop_at_newline': False,
|
||||||
|
'add_bos_token': True,
|
||||||
|
'ban_eos_token': False,
|
||||||
|
'truncation_length': 2048,
|
||||||
|
'truncation_length_min': 0,
|
||||||
|
'truncation_length_max': 4096,
|
||||||
'chat_prompt_size': 2048,
|
'chat_prompt_size': 2048,
|
||||||
'chat_prompt_size_min': 0,
|
'chat_prompt_size_min': 0,
|
||||||
'chat_prompt_size_max': 2048,
|
'chat_prompt_size_max': 2048,
|
||||||
@@ -44,7 +50,7 @@ settings = {
|
|||||||
'default_extensions': [],
|
'default_extensions': [],
|
||||||
'chat_default_extensions': ["gallery"],
|
'chat_default_extensions': ["gallery"],
|
||||||
'presets': {
|
'presets': {
|
||||||
'default': 'NovelAI-Sphinx Moth',
|
'default': 'Default',
|
||||||
'.*(alpaca|llama)': "LLaMA-Precise",
|
'.*(alpaca|llama)': "LLaMA-Precise",
|
||||||
'.*pygmalion': 'NovelAI-Storywriter',
|
'.*pygmalion': 'NovelAI-Storywriter',
|
||||||
'.*RWKV': 'Naive',
|
'.*RWKV': 'Naive',
|
||||||
@@ -89,7 +95,7 @@ parser.add_argument('--extensions', type=str, nargs="+", help='The list of exten
|
|||||||
parser.add_argument('--verbose', action='store_true', help='Print the prompts to the terminal.')
|
parser.add_argument('--verbose', action='store_true', help='Print the prompts to the terminal.')
|
||||||
|
|
||||||
# Accelerate/transformers
|
# Accelerate/transformers
|
||||||
parser.add_argument('--cpu', action='store_true', help='Use the CPU to generate text.')
|
parser.add_argument('--cpu', action='store_true', help='Use the CPU to generate text. Warning: Training on CPU is extremely slow.')
|
||||||
parser.add_argument('--auto-devices', action='store_true', help='Automatically split the model across the available GPU(s) and CPU.')
|
parser.add_argument('--auto-devices', action='store_true', help='Automatically split the model across the available GPU(s) and CPU.')
|
||||||
parser.add_argument('--gpu-memory', type=str, nargs="+", help='Maxmimum GPU memory in GiB to be allocated per GPU. Example: --gpu-memory 10 for a single GPU, --gpu-memory 10 5 for two GPUs. You can also set values in MiB like --gpu-memory 3500MiB.')
|
parser.add_argument('--gpu-memory', type=str, nargs="+", help='Maxmimum GPU memory in GiB to be allocated per GPU. Example: --gpu-memory 10 for a single GPU, --gpu-memory 10 5 for two GPUs. You can also set values in MiB like --gpu-memory 3500MiB.')
|
||||||
parser.add_argument('--cpu-memory', type=str, help='Maximum CPU memory in GiB to allocate for offloaded weights. Same as above.')
|
parser.add_argument('--cpu-memory', type=str, help='Maximum CPU memory in GiB to allocate for offloaded weights. Same as above.')
|
||||||
@@ -98,6 +104,8 @@ parser.add_argument('--disk-cache-dir', type=str, default="cache", help='Directo
|
|||||||
parser.add_argument('--load-in-8bit', action='store_true', help='Load the model with 8-bit precision.')
|
parser.add_argument('--load-in-8bit', action='store_true', help='Load the model with 8-bit precision.')
|
||||||
parser.add_argument('--bf16', action='store_true', help='Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU.')
|
parser.add_argument('--bf16', action='store_true', help='Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU.')
|
||||||
parser.add_argument('--no-cache', action='store_true', help='Set use_cache to False while generating text. This reduces the VRAM usage a bit at a performance cost.')
|
parser.add_argument('--no-cache', action='store_true', help='Set use_cache to False while generating text. This reduces the VRAM usage a bit at a performance cost.')
|
||||||
|
parser.add_argument('--xformers', action='store_true', help="Use xformer's memory efficient attention. This should increase your tokens/s.")
|
||||||
|
parser.add_argument('--sdp-attention', action='store_true', help="Use torch 2.0's sdp attention.")
|
||||||
|
|
||||||
# llama.cpp
|
# llama.cpp
|
||||||
parser.add_argument('--threads', type=int, default=0, help='Number of threads to use in llama.cpp.')
|
parser.add_argument('--threads', type=int, default=0, help='Number of threads to use in llama.cpp.')
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
import gc
|
import random
|
||||||
import re
|
import re
|
||||||
import time
|
import time
|
||||||
import traceback
|
import traceback
|
||||||
@@ -12,28 +12,38 @@ from modules.callbacks import (Iteratorize, Stream,
|
|||||||
_SentinelTokenStoppingCriteria)
|
_SentinelTokenStoppingCriteria)
|
||||||
from modules.extensions import apply_extensions
|
from modules.extensions import apply_extensions
|
||||||
from modules.html_generator import generate_4chan_html, generate_basic_html
|
from modules.html_generator import generate_4chan_html, generate_basic_html
|
||||||
from modules.models import local_rank
|
from modules.models import clear_torch_cache, local_rank
|
||||||
|
|
||||||
|
|
||||||
def get_max_prompt_length(tokens):
|
def get_max_prompt_length(state):
|
||||||
max_length = 2048 - tokens
|
max_length = state['truncation_length'] - state['max_new_tokens']
|
||||||
if shared.soft_prompt:
|
if shared.soft_prompt:
|
||||||
max_length -= shared.soft_prompt_tensor.shape[1]
|
max_length -= shared.soft_prompt_tensor.shape[1]
|
||||||
return max_length
|
return max_length
|
||||||
|
|
||||||
|
|
||||||
def encode(prompt, tokens_to_generate=0, add_special_tokens=True):
|
def encode(prompt, add_special_tokens=True, add_bos_token=True, truncation_length=None):
|
||||||
if any((shared.is_RWKV, shared.is_llamacpp)):
|
if any((shared.is_RWKV, shared.is_llamacpp)):
|
||||||
input_ids = shared.tokenizer.encode(str(prompt))
|
input_ids = shared.tokenizer.encode(str(prompt))
|
||||||
input_ids = np.array(input_ids).reshape(1, len(input_ids))
|
input_ids = np.array(input_ids).reshape(1, len(input_ids))
|
||||||
return input_ids
|
return input_ids
|
||||||
else:
|
else:
|
||||||
input_ids = shared.tokenizer.encode(str(prompt), return_tensors='pt', truncation=True, max_length=get_max_prompt_length(tokens_to_generate), add_special_tokens=add_special_tokens)
|
input_ids = shared.tokenizer.encode(str(prompt), return_tensors='pt', add_special_tokens=add_special_tokens)
|
||||||
|
|
||||||
|
# This is a hack for making replies more creative.
|
||||||
|
if not add_bos_token and input_ids[0][0] == shared.tokenizer.bos_token_id:
|
||||||
|
input_ids = input_ids[:, 1:]
|
||||||
|
|
||||||
|
# Llama adds this extra token when the first character is '\n', and this
|
||||||
|
# compromises the stopping criteria, so we just remove it
|
||||||
if type(shared.tokenizer) is transformers.LlamaTokenizer and input_ids[0][0] == 29871:
|
if type(shared.tokenizer) is transformers.LlamaTokenizer and input_ids[0][0] == 29871:
|
||||||
input_ids = input_ids[:, 1:]
|
input_ids = input_ids[:, 1:]
|
||||||
|
|
||||||
if shared.args.cpu:
|
# Handling truncation
|
||||||
|
if truncation_length is not None:
|
||||||
|
input_ids = input_ids[:, -truncation_length:]
|
||||||
|
|
||||||
|
if any((shared.is_RWKV, shared.is_llamacpp, shared.args.cpu)):
|
||||||
return input_ids
|
return input_ids
|
||||||
elif shared.args.flexgen:
|
elif shared.args.flexgen:
|
||||||
return input_ids.numpy()
|
return input_ids.numpy()
|
||||||
@@ -63,9 +73,8 @@ def generate_softprompt_input_tensors(input_ids):
|
|||||||
# filler_input_ids += shared.model.config.bos_token_id # setting dummy input_ids to bos tokens
|
# filler_input_ids += shared.model.config.bos_token_id # setting dummy input_ids to bos tokens
|
||||||
return inputs_embeds, filler_input_ids
|
return inputs_embeds, filler_input_ids
|
||||||
|
|
||||||
|
|
||||||
# Removes empty replies from gpt4chan outputs
|
# Removes empty replies from gpt4chan outputs
|
||||||
|
|
||||||
|
|
||||||
def fix_gpt4chan(s):
|
def fix_gpt4chan(s):
|
||||||
for i in range(10):
|
for i in range(10):
|
||||||
s = re.sub("--- [0-9]*\n>>[0-9]*\n---", "---", s)
|
s = re.sub("--- [0-9]*\n>>[0-9]*\n---", "---", s)
|
||||||
@@ -73,9 +82,8 @@ def fix_gpt4chan(s):
|
|||||||
s = re.sub("--- [0-9]*\n\n\n---", "---", s)
|
s = re.sub("--- [0-9]*\n\n\n---", "---", s)
|
||||||
return s
|
return s
|
||||||
|
|
||||||
|
|
||||||
# Fix the LaTeX equations in galactica
|
# Fix the LaTeX equations in galactica
|
||||||
|
|
||||||
|
|
||||||
def fix_galactica(s):
|
def fix_galactica(s):
|
||||||
s = s.replace(r'\[', r'$')
|
s = s.replace(r'\[', r'$')
|
||||||
s = s.replace(r'\]', r'$')
|
s = s.replace(r'\]', r'$')
|
||||||
@@ -101,48 +109,47 @@ def formatted_outputs(reply, model_name):
|
|||||||
return reply
|
return reply
|
||||||
|
|
||||||
|
|
||||||
def clear_torch_cache():
|
|
||||||
gc.collect()
|
|
||||||
if not shared.args.cpu:
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
|
|
||||||
|
|
||||||
def set_manual_seed(seed):
|
def set_manual_seed(seed):
|
||||||
if seed != -1:
|
seed = int(seed)
|
||||||
|
if seed == -1:
|
||||||
|
seed = random.randint(1, 2**31)
|
||||||
torch.manual_seed(seed)
|
torch.manual_seed(seed)
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
torch.cuda.manual_seed_all(seed)
|
torch.cuda.manual_seed_all(seed)
|
||||||
|
return seed
|
||||||
|
|
||||||
|
|
||||||
def stop_everything_event():
|
def stop_everything_event():
|
||||||
shared.stop_everything = True
|
shared.stop_everything = True
|
||||||
|
|
||||||
|
|
||||||
def generate_reply(question, generate_state, eos_token=None, stopping_strings=[]):
|
def generate_reply(question, state, eos_token=None, stopping_strings=[]):
|
||||||
clear_torch_cache()
|
clear_torch_cache()
|
||||||
set_manual_seed(generate_state['seed'])
|
seed = set_manual_seed(state['seed'])
|
||||||
shared.stop_everything = False
|
shared.stop_everything = False
|
||||||
generate_params = {}
|
generate_params = {}
|
||||||
t0 = time.time()
|
t0 = time.time()
|
||||||
|
|
||||||
original_question = question
|
original_question = question
|
||||||
if not shared.is_chat():
|
if not shared.is_chat():
|
||||||
question = apply_extensions(question, "input")
|
question = apply_extensions(question, 'input')
|
||||||
if shared.args.verbose:
|
|
||||||
print(f"\n\n{question}\n--------------------\n")
|
|
||||||
|
|
||||||
# These models are not part of Hugging Face, so we handle them
|
# These models are not part of Hugging Face, so we handle them
|
||||||
# separately and terminate the function call earlier
|
# separately and terminate the function call earlier
|
||||||
if any((shared.is_RWKV, shared.is_llamacpp)):
|
if any((shared.is_RWKV, shared.is_llamacpp)):
|
||||||
|
|
||||||
|
if shared.args.verbose:
|
||||||
|
print(f'\n\n{question}\n--------------------\n')
|
||||||
|
|
||||||
for k in ['temperature', 'top_p', 'top_k', 'repetition_penalty']:
|
for k in ['temperature', 'top_p', 'top_k', 'repetition_penalty']:
|
||||||
generate_params[k] = generate_state[k]
|
generate_params[k] = state[k]
|
||||||
generate_params["token_count"] = generate_state["max_new_tokens"]
|
generate_params['token_count'] = state['max_new_tokens']
|
||||||
try:
|
try:
|
||||||
if shared.args.no_stream:
|
if shared.args.no_stream:
|
||||||
reply = shared.model.generate(context=question, **generate_params)
|
reply = shared.model.generate(context=question, **generate_params)
|
||||||
output = original_question + reply
|
output = original_question + reply
|
||||||
if not shared.is_chat():
|
if not shared.is_chat():
|
||||||
reply = original_question + apply_extensions(reply, "output")
|
reply = original_question + apply_extensions(reply, 'output')
|
||||||
yield formatted_outputs(reply, shared.model_name)
|
yield formatted_outputs(reply, shared.model_name)
|
||||||
else:
|
else:
|
||||||
if not shared.is_chat():
|
if not shared.is_chat():
|
||||||
@@ -153,7 +160,7 @@ def generate_reply(question, generate_state, eos_token=None, stopping_strings=[]
|
|||||||
for reply in shared.model.generate_with_streaming(context=question, **generate_params):
|
for reply in shared.model.generate_with_streaming(context=question, **generate_params):
|
||||||
output = original_question + reply
|
output = original_question + reply
|
||||||
if not shared.is_chat():
|
if not shared.is_chat():
|
||||||
reply = original_question + apply_extensions(reply, "output")
|
reply = original_question + apply_extensions(reply, 'output')
|
||||||
yield formatted_outputs(reply, shared.model_name)
|
yield formatted_outputs(reply, shared.model_name)
|
||||||
|
|
||||||
except Exception:
|
except Exception:
|
||||||
@@ -162,47 +169,53 @@ def generate_reply(question, generate_state, eos_token=None, stopping_strings=[]
|
|||||||
t1 = time.time()
|
t1 = time.time()
|
||||||
original_tokens = len(encode(original_question)[0])
|
original_tokens = len(encode(original_question)[0])
|
||||||
new_tokens = len(encode(output)[0]) - original_tokens
|
new_tokens = len(encode(output)[0]) - original_tokens
|
||||||
print(f"Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens})")
|
print(f'Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens}, seed {seed})')
|
||||||
return
|
return
|
||||||
|
|
||||||
input_ids = encode(question, generate_state['max_new_tokens'])
|
input_ids = encode(question, add_bos_token=state['add_bos_token'], truncation_length=get_max_prompt_length(state))
|
||||||
original_input_ids = input_ids
|
original_input_ids = input_ids
|
||||||
output = input_ids[0]
|
output = input_ids[0]
|
||||||
|
|
||||||
|
if shared.args.verbose:
|
||||||
|
print(f'\n\n{decode(input_ids[0])}\n--------------------\n')
|
||||||
|
|
||||||
cuda = not any((shared.args.cpu, shared.args.deepspeed, shared.args.flexgen))
|
cuda = not any((shared.args.cpu, shared.args.deepspeed, shared.args.flexgen))
|
||||||
eos_token_ids = [shared.tokenizer.eos_token_id] if shared.tokenizer.eos_token_id is not None else []
|
eos_token_ids = [shared.tokenizer.eos_token_id] if shared.tokenizer.eos_token_id is not None else []
|
||||||
if eos_token is not None:
|
if eos_token is not None:
|
||||||
eos_token_ids.append(int(encode(eos_token)[0][-1]))
|
eos_token_ids.append(int(encode(eos_token)[0][-1]))
|
||||||
stopping_criteria_list = transformers.StoppingCriteriaList()
|
|
||||||
if type(stopping_strings) is list and len(stopping_strings) > 0:
|
|
||||||
t = [encode(string, 0, add_special_tokens=False) for string in stopping_strings]
|
|
||||||
stopping_criteria_list.append(_SentinelTokenStoppingCriteria(sentinel_token_ids=t, starting_idx=len(input_ids[0])))
|
|
||||||
|
|
||||||
generate_params["max_new_tokens"] = generate_state['max_new_tokens']
|
# Handling the stopping strings
|
||||||
|
stopping_criteria_list = transformers.StoppingCriteriaList()
|
||||||
|
for st in [stopping_strings, state['custom_stopping_strings']]:
|
||||||
|
if type(st) is list and len(st) > 0:
|
||||||
|
sentinel_token_ids = [encode(string, add_special_tokens=False) for string in st]
|
||||||
|
stopping_criteria_list.append(_SentinelTokenStoppingCriteria(sentinel_token_ids=sentinel_token_ids, starting_idx=len(input_ids[0])))
|
||||||
|
break
|
||||||
|
|
||||||
if not shared.args.flexgen:
|
if not shared.args.flexgen:
|
||||||
for k in ["do_sample", "temperature", "top_p", "typical_p", "repetition_penalty", "encoder_repetition_penalty", "top_k", "min_length", "no_repeat_ngram_size", "num_beams", "penalty_alpha", "length_penalty", "early_stopping"]:
|
for k in ['max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']:
|
||||||
generate_params[k] = generate_state[k]
|
generate_params[k] = state[k]
|
||||||
generate_params["eos_token_id"] = eos_token_ids
|
generate_params['eos_token_id'] = eos_token_ids
|
||||||
generate_params["stopping_criteria"] = stopping_criteria_list
|
generate_params['stopping_criteria'] = stopping_criteria_list
|
||||||
if shared.args.no_stream:
|
if state['ban_eos_token']:
|
||||||
generate_params["min_length"] = 0
|
generate_params['suppress_tokens'] = [shared.tokenizer.eos_token_id]
|
||||||
else:
|
else:
|
||||||
for k in ["do_sample", "temperature"]:
|
for k in ['max_new_tokens', 'do_sample', 'temperature']:
|
||||||
generate_params[k] = generate_state[k]
|
generate_params[k] = state[k]
|
||||||
generate_params["stop"] = generate_state["eos_token_ids"][-1]
|
generate_params['stop'] = state['eos_token_ids'][-1]
|
||||||
if not shared.args.no_stream:
|
if not shared.args.no_stream:
|
||||||
generate_params["max_new_tokens"] = 8
|
generate_params['max_new_tokens'] = 8
|
||||||
|
|
||||||
if shared.args.no_cache:
|
if shared.args.no_cache:
|
||||||
generate_params.update({"use_cache": False})
|
generate_params.update({'use_cache': False})
|
||||||
if shared.args.deepspeed:
|
if shared.args.deepspeed:
|
||||||
generate_params.update({"synced_gpus": True})
|
generate_params.update({'synced_gpus': True})
|
||||||
if shared.soft_prompt:
|
if shared.soft_prompt:
|
||||||
inputs_embeds, filler_input_ids = generate_softprompt_input_tensors(input_ids)
|
inputs_embeds, filler_input_ids = generate_softprompt_input_tensors(input_ids)
|
||||||
generate_params.update({"inputs_embeds": inputs_embeds})
|
generate_params.update({'inputs_embeds': inputs_embeds})
|
||||||
generate_params.update({"inputs": filler_input_ids})
|
generate_params.update({'inputs': filler_input_ids})
|
||||||
else:
|
else:
|
||||||
generate_params.update({"inputs": input_ids})
|
generate_params.update({'inputs': input_ids})
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Generate the entire reply at once.
|
# Generate the entire reply at once.
|
||||||
@@ -217,7 +230,7 @@ def generate_reply(question, generate_state, eos_token=None, stopping_strings=[]
|
|||||||
new_tokens = len(output) - len(input_ids[0])
|
new_tokens = len(output) - len(input_ids[0])
|
||||||
reply = decode(output[-new_tokens:])
|
reply = decode(output[-new_tokens:])
|
||||||
if not shared.is_chat():
|
if not shared.is_chat():
|
||||||
reply = original_question + apply_extensions(reply, "output")
|
reply = original_question + apply_extensions(reply, 'output')
|
||||||
|
|
||||||
yield formatted_outputs(reply, shared.model_name)
|
yield formatted_outputs(reply, shared.model_name)
|
||||||
|
|
||||||
@@ -244,7 +257,7 @@ def generate_reply(question, generate_state, eos_token=None, stopping_strings=[]
|
|||||||
new_tokens = len(output) - len(input_ids[0])
|
new_tokens = len(output) - len(input_ids[0])
|
||||||
reply = decode(output[-new_tokens:])
|
reply = decode(output[-new_tokens:])
|
||||||
if not shared.is_chat():
|
if not shared.is_chat():
|
||||||
reply = original_question + apply_extensions(reply, "output")
|
reply = original_question + apply_extensions(reply, 'output')
|
||||||
|
|
||||||
if output[-1] in eos_token_ids:
|
if output[-1] in eos_token_ids:
|
||||||
break
|
break
|
||||||
@@ -252,7 +265,7 @@ def generate_reply(question, generate_state, eos_token=None, stopping_strings=[]
|
|||||||
|
|
||||||
# Stream the output naively for FlexGen since it doesn't support 'stopping_criteria'
|
# Stream the output naively for FlexGen since it doesn't support 'stopping_criteria'
|
||||||
else:
|
else:
|
||||||
for i in range(generate_state['max_new_tokens'] // 8 + 1):
|
for i in range(state['max_new_tokens'] // 8 + 1):
|
||||||
clear_torch_cache()
|
clear_torch_cache()
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
output = shared.model.generate(**generate_params)[0]
|
output = shared.model.generate(**generate_params)[0]
|
||||||
@@ -262,7 +275,7 @@ def generate_reply(question, generate_state, eos_token=None, stopping_strings=[]
|
|||||||
new_tokens = len(output) - len(original_input_ids[0])
|
new_tokens = len(output) - len(original_input_ids[0])
|
||||||
reply = decode(output[-new_tokens:])
|
reply = decode(output[-new_tokens:])
|
||||||
if not shared.is_chat():
|
if not shared.is_chat():
|
||||||
reply = original_question + apply_extensions(reply, "output")
|
reply = original_question + apply_extensions(reply, 'output')
|
||||||
|
|
||||||
if np.count_nonzero(np.isin(input_ids[0], eos_token_ids)) < np.count_nonzero(np.isin(output, eos_token_ids)):
|
if np.count_nonzero(np.isin(input_ids[0], eos_token_ids)) < np.count_nonzero(np.isin(output, eos_token_ids)):
|
||||||
break
|
break
|
||||||
@@ -271,10 +284,10 @@ def generate_reply(question, generate_state, eos_token=None, stopping_strings=[]
|
|||||||
input_ids = np.reshape(output, (1, output.shape[0]))
|
input_ids = np.reshape(output, (1, output.shape[0]))
|
||||||
if shared.soft_prompt:
|
if shared.soft_prompt:
|
||||||
inputs_embeds, filler_input_ids = generate_softprompt_input_tensors(input_ids)
|
inputs_embeds, filler_input_ids = generate_softprompt_input_tensors(input_ids)
|
||||||
generate_params.update({"inputs_embeds": inputs_embeds})
|
generate_params.update({'inputs_embeds': inputs_embeds})
|
||||||
generate_params.update({"inputs": filler_input_ids})
|
generate_params.update({'inputs': filler_input_ids})
|
||||||
else:
|
else:
|
||||||
generate_params.update({"inputs": input_ids})
|
generate_params.update({'inputs': input_ids})
|
||||||
|
|
||||||
yield formatted_outputs(reply, shared.model_name)
|
yield formatted_outputs(reply, shared.model_name)
|
||||||
|
|
||||||
@@ -284,5 +297,5 @@ def generate_reply(question, generate_state, eos_token=None, stopping_strings=[]
|
|||||||
t1 = time.time()
|
t1 = time.time()
|
||||||
original_tokens = len(original_input_ids[0])
|
original_tokens = len(original_input_ids[0])
|
||||||
new_tokens = len(output) - original_tokens
|
new_tokens = len(output) - original_tokens
|
||||||
print(f"Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens})")
|
print(f'Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens}, seed {seed})')
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -152,7 +152,7 @@ def do_train(lora_name: str, micro_batch_size: int, batch_size: int, epochs: int
|
|||||||
# == Prep the dataset, format, etc ==
|
# == Prep the dataset, format, etc ==
|
||||||
if raw_text_file not in ['None', '']:
|
if raw_text_file not in ['None', '']:
|
||||||
print("Loading raw text file dataset...")
|
print("Loading raw text file dataset...")
|
||||||
with open(clean_path('training/datasets', f'{raw_text_file}.txt'), 'r') as file:
|
with open(clean_path('training/datasets', f'{raw_text_file}.txt'), 'r', encoding='utf-8') as file:
|
||||||
raw_text = file.read()
|
raw_text = file.read()
|
||||||
tokens = shared.tokenizer.encode(raw_text)
|
tokens = shared.tokenizer.encode(raw_text)
|
||||||
del raw_text # Note: could be a gig for a large dataset, so delete redundant data as we go to be safe on RAM
|
del raw_text # Note: could be a gig for a large dataset, so delete redundant data as we go to be safe on RAM
|
||||||
@@ -238,7 +238,7 @@ def do_train(lora_name: str, micro_batch_size: int, batch_size: int, epochs: int
|
|||||||
warmup_steps=100,
|
warmup_steps=100,
|
||||||
num_train_epochs=epochs,
|
num_train_epochs=epochs,
|
||||||
learning_rate=actual_lr,
|
learning_rate=actual_lr,
|
||||||
fp16=True,
|
fp16=False if shared.args.cpu else True,
|
||||||
logging_steps=20,
|
logging_steps=20,
|
||||||
evaluation_strategy="steps" if eval_data is not None else "no",
|
evaluation_strategy="steps" if eval_data is not None else "no",
|
||||||
save_strategy="steps",
|
save_strategy="steps",
|
||||||
@@ -248,7 +248,8 @@ def do_train(lora_name: str, micro_batch_size: int, batch_size: int, epochs: int
|
|||||||
save_total_limit=3,
|
save_total_limit=3,
|
||||||
load_best_model_at_end=True if eval_data is not None else False,
|
load_best_model_at_end=True if eval_data is not None else False,
|
||||||
# TODO: Enable multi-device support
|
# TODO: Enable multi-device support
|
||||||
ddp_find_unused_parameters=None
|
ddp_find_unused_parameters=None,
|
||||||
|
no_cuda=shared.args.cpu
|
||||||
),
|
),
|
||||||
data_collator=transformers.DataCollatorForLanguageModeling(shared.tokenizer, mlm=False),
|
data_collator=transformers.DataCollatorForLanguageModeling(shared.tokenizer, mlm=False),
|
||||||
callbacks=list([Callbacks()])
|
callbacks=list([Callbacks()])
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
do_sample=True
|
do_sample=True
|
||||||
top_p=0.5
|
top_p=0.95
|
||||||
top_k=40
|
top_k=50
|
||||||
temperature=0.7
|
temperature=1
|
||||||
repetition_penalty=1.2
|
repetition_penalty=1.2
|
||||||
typical_p=1.0
|
typical_p=1.0
|
||||||
early_stopping=False
|
|
||||||
|
|||||||
@@ -1,10 +1,10 @@
|
|||||||
accelerate==0.18.0
|
accelerate==0.18.0
|
||||||
bitsandbytes==0.37.2
|
|
||||||
datasets
|
datasets
|
||||||
flexgen==0.1.7
|
flexgen==0.1.7
|
||||||
gradio==3.24.1
|
gradio==3.24.1
|
||||||
markdown
|
markdown
|
||||||
numpy
|
numpy
|
||||||
|
Pillow>=9.5.0
|
||||||
peft==0.2.0
|
peft==0.2.0
|
||||||
requests
|
requests
|
||||||
rwkv==0.7.3
|
rwkv==0.7.3
|
||||||
@@ -13,3 +13,6 @@ sentencepiece
|
|||||||
pyyaml
|
pyyaml
|
||||||
tqdm
|
tqdm
|
||||||
git+https://github.com/huggingface/transformers
|
git+https://github.com/huggingface/transformers
|
||||||
|
bitsandbytes==0.37.2; platform_system != "Windows"
|
||||||
|
llama-cpp-python==0.1.30; platform_system != "Windows"
|
||||||
|
https://github.com/abetlen/llama-cpp-python/releases/download/v0.1.30/llama_cpp_python-0.1.30-cp310-cp310-win_amd64.whl; platform_system == "Windows"
|
||||||
|
|||||||
320
server.py
320
server.py
@@ -2,11 +2,14 @@ import os
|
|||||||
|
|
||||||
os.environ['GRADIO_ANALYTICS_ENABLED'] = 'False'
|
os.environ['GRADIO_ANALYTICS_ENABLED'] = 'False'
|
||||||
|
|
||||||
|
import importlib
|
||||||
import io
|
import io
|
||||||
import json
|
import json
|
||||||
|
import os
|
||||||
import re
|
import re
|
||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
|
import traceback
|
||||||
import zipfile
|
import zipfile
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@@ -15,12 +18,12 @@ import gradio as gr
|
|||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
import modules.extensions as extensions_module
|
import modules.extensions as extensions_module
|
||||||
from modules import chat, shared, training, ui, api
|
from modules import api, chat, shared, training, ui
|
||||||
from modules.html_generator import chat_html_wrapper
|
from modules.html_generator import chat_html_wrapper
|
||||||
from modules.LoRA import add_lora_to_model
|
from modules.LoRA import add_lora_to_model
|
||||||
from modules.models import load_model, load_soft_prompt
|
from modules.models import load_model, load_soft_prompt, unload_model
|
||||||
from modules.text_generation import (clear_torch_cache, generate_reply,
|
from modules.text_generation import generate_reply, stop_everything_event
|
||||||
stop_everything_event)
|
|
||||||
|
|
||||||
# Loading custom settings
|
# Loading custom settings
|
||||||
settings_file = None
|
settings_file = None
|
||||||
@@ -79,11 +82,6 @@ def get_available_loras():
|
|||||||
return ['None'] + sorted([item.name for item in list(Path(shared.args.lora_dir).glob('*')) if not item.name.endswith(('.txt', '-np', '.pt', '.json'))], key=str.lower)
|
return ['None'] + sorted([item.name for item in list(Path(shared.args.lora_dir).glob('*')) if not item.name.endswith(('.txt', '-np', '.pt', '.json'))], key=str.lower)
|
||||||
|
|
||||||
|
|
||||||
def unload_model():
|
|
||||||
shared.model = shared.tokenizer = None
|
|
||||||
clear_torch_cache()
|
|
||||||
|
|
||||||
|
|
||||||
def load_model_wrapper(selected_model):
|
def load_model_wrapper(selected_model):
|
||||||
if selected_model != shared.model_name:
|
if selected_model != shared.model_name:
|
||||||
shared.model_name = selected_model
|
shared.model_name = selected_model
|
||||||
@@ -178,6 +176,34 @@ def create_prompt_menus():
|
|||||||
shared.gradio['save_prompt'].click(save_prompt, [shared.gradio['textbox']], [shared.gradio['status']], show_progress=False)
|
shared.gradio['save_prompt'].click(save_prompt, [shared.gradio['textbox']], [shared.gradio['status']], show_progress=False)
|
||||||
|
|
||||||
|
|
||||||
|
def download_model_wrapper(repo_id):
|
||||||
|
try:
|
||||||
|
downloader = importlib.import_module("download-model")
|
||||||
|
|
||||||
|
model = repo_id
|
||||||
|
branch = "main"
|
||||||
|
check = False
|
||||||
|
|
||||||
|
yield ("Cleaning up the model/branch names")
|
||||||
|
model, branch = downloader.sanitize_model_and_branch_names(model, branch)
|
||||||
|
|
||||||
|
yield ("Getting the download links from Hugging Face")
|
||||||
|
links, sha256, is_lora = downloader.get_download_links_from_huggingface(model, branch, text_only=False)
|
||||||
|
|
||||||
|
yield ("Getting the output folder")
|
||||||
|
output_folder = downloader.get_output_folder(model, branch, is_lora)
|
||||||
|
|
||||||
|
if check:
|
||||||
|
yield ("Checking previously downloaded files")
|
||||||
|
downloader.check_model_files(model, branch, links, sha256, output_folder)
|
||||||
|
else:
|
||||||
|
yield (f"Downloading files to {output_folder}")
|
||||||
|
downloader.download_model_files(model, branch, links, sha256, output_folder, threads=1)
|
||||||
|
yield ("Done!")
|
||||||
|
except:
|
||||||
|
yield traceback.format_exc()
|
||||||
|
|
||||||
|
|
||||||
def create_model_menus():
|
def create_model_menus():
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
@@ -188,16 +214,26 @@ def create_model_menus():
|
|||||||
with gr.Row():
|
with gr.Row():
|
||||||
shared.gradio['lora_menu'] = gr.Dropdown(choices=available_loras, value=shared.lora_name, label='LoRA')
|
shared.gradio['lora_menu'] = gr.Dropdown(choices=available_loras, value=shared.lora_name, label='LoRA')
|
||||||
ui.create_refresh_button(shared.gradio['lora_menu'], lambda: None, lambda: {'choices': get_available_loras()}, 'refresh-button')
|
ui.create_refresh_button(shared.gradio['lora_menu'], lambda: None, lambda: {'choices': get_available_loras()}, 'refresh-button')
|
||||||
|
with gr.Row():
|
||||||
|
with gr.Column():
|
||||||
|
with gr.Row():
|
||||||
|
with gr.Column():
|
||||||
|
shared.gradio['custom_model_menu'] = gr.Textbox(label="Download custom model or LoRA",
|
||||||
|
info="Enter Hugging Face username/model path, e.g: facebook/galactica-125m")
|
||||||
|
with gr.Column():
|
||||||
|
shared.gradio['download_button'] = gr.Button("Download")
|
||||||
|
shared.gradio['download_status'] = gr.Markdown()
|
||||||
|
with gr.Column():
|
||||||
|
pass
|
||||||
|
|
||||||
shared.gradio['model_menu'].change(load_model_wrapper, shared.gradio['model_menu'], shared.gradio['model_menu'], show_progress=True)
|
shared.gradio['model_menu'].change(load_model_wrapper, shared.gradio['model_menu'], shared.gradio['model_menu'], show_progress=True)
|
||||||
shared.gradio['lora_menu'].change(load_lora_wrapper, shared.gradio['lora_menu'], shared.gradio['lora_menu'], show_progress=True)
|
shared.gradio['lora_menu'].change(load_lora_wrapper, shared.gradio['lora_menu'], shared.gradio['lora_menu'], show_progress=True)
|
||||||
|
shared.gradio['download_button'].click(download_model_wrapper, shared.gradio['custom_model_menu'], shared.gradio['download_status'], show_progress=False)
|
||||||
|
|
||||||
|
|
||||||
def create_settings_menus(default_preset):
|
def create_settings_menus(default_preset):
|
||||||
|
|
||||||
generate_params = load_preset_values(default_preset if not shared.args.flexgen else 'Naive', {}, return_dict=True)
|
generate_params = load_preset_values(default_preset if not shared.args.flexgen else 'Naive', {}, return_dict=True)
|
||||||
for k in ['max_new_tokens', 'seed', 'stop_at_newline', 'chat_prompt_size', 'chat_generation_attempts']:
|
|
||||||
generate_params[k] = shared.settings[k]
|
|
||||||
shared.gradio['generate_state'] = gr.State(generate_params)
|
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
@@ -210,24 +246,24 @@ def create_settings_menus(default_preset):
|
|||||||
with gr.Row():
|
with gr.Row():
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
with gr.Box():
|
with gr.Box():
|
||||||
gr.Markdown('Custom generation parameters ([reference](https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.GenerationConfig))')
|
gr.Markdown('Custom generation parameters ([click here to view technical documentation](https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.GenerationConfig))')
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
shared.gradio['temperature'] = gr.Slider(0.01, 1.99, value=generate_params['temperature'], step=0.01, label='temperature')
|
shared.gradio['temperature'] = gr.Slider(0.01, 1.99, value=generate_params['temperature'], step=0.01, label='temperature', info='Primary factor to control randomness of outputs. 0 = deterministic (only the most likely token is used). Higher value = more randomness.')
|
||||||
shared.gradio['top_p'] = gr.Slider(0.0, 1.0, value=generate_params['top_p'], step=0.01, label='top_p')
|
shared.gradio['top_p'] = gr.Slider(0.0, 1.0, value=generate_params['top_p'], step=0.01, label='top_p', info='If not set to 1, select tokens with probabilities adding up to less than this number. Higher value = higher range of possible random results.')
|
||||||
shared.gradio['top_k'] = gr.Slider(0, 200, value=generate_params['top_k'], step=1, label='top_k')
|
shared.gradio['top_k'] = gr.Slider(0, 200, value=generate_params['top_k'], step=1, label='top_k', info='Similar to top_p, but select instead only the top_k most likely tokens. Higher value = higher range of possible random results.')
|
||||||
shared.gradio['typical_p'] = gr.Slider(0.0, 1.0, value=generate_params['typical_p'], step=0.01, label='typical_p')
|
shared.gradio['typical_p'] = gr.Slider(0.0, 1.0, value=generate_params['typical_p'], step=0.01, label='typical_p', info='If not set to 1, select only tokens that are at least this much more likely to appear than random tokens, given the prior text.')
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
shared.gradio['repetition_penalty'] = gr.Slider(1.0, 1.5, value=generate_params['repetition_penalty'], step=0.01, label='repetition_penalty')
|
shared.gradio['repetition_penalty'] = gr.Slider(1.0, 1.5, value=generate_params['repetition_penalty'], step=0.01, label='repetition_penalty', info='Exponential penalty factor for repeating prior tokens. 1 means no penalty, higher value = less repetition, lower value = more repetition.')
|
||||||
shared.gradio['encoder_repetition_penalty'] = gr.Slider(0.8, 1.5, value=generate_params['encoder_repetition_penalty'], step=0.01, label='encoder_repetition_penalty')
|
shared.gradio['encoder_repetition_penalty'] = gr.Slider(0.8, 1.5, value=generate_params['encoder_repetition_penalty'], step=0.01, label='encoder_repetition_penalty', info='Also known as the "Hallucinations filter". Used to penalize tokens that are *not* in the prior text. Higher value = more likely to stay in context, lower value = more likely to diverge.')
|
||||||
shared.gradio['no_repeat_ngram_size'] = gr.Slider(0, 20, step=1, value=generate_params['no_repeat_ngram_size'], label='no_repeat_ngram_size')
|
shared.gradio['no_repeat_ngram_size'] = gr.Slider(0, 20, step=1, value=generate_params['no_repeat_ngram_size'], label='no_repeat_ngram_size', info='If not set to 0, specifies the length of token sets that are completely blocked from repeating at all. Higher values = blocks larger phrases, lower values = blocks words or letters from repeating. Only 0 or high values are a good idea in most cases.')
|
||||||
shared.gradio['min_length'] = gr.Slider(0, 2000, step=1, value=generate_params['min_length'] if shared.args.no_stream else 0, label='min_length', interactive=shared.args.no_stream)
|
shared.gradio['min_length'] = gr.Slider(0, 2000, step=1, value=generate_params['min_length'], label='min_length', info='Minimum generation length in tokens.')
|
||||||
shared.gradio['do_sample'] = gr.Checkbox(value=generate_params['do_sample'], label='do_sample')
|
shared.gradio['do_sample'] = gr.Checkbox(value=generate_params['do_sample'], label='do_sample')
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
with gr.Box():
|
with gr.Box():
|
||||||
gr.Markdown('Contrastive search')
|
gr.Markdown('Contrastive search')
|
||||||
shared.gradio['penalty_alpha'] = gr.Slider(0, 5, value=generate_params['penalty_alpha'], label='penalty_alpha')
|
shared.gradio['penalty_alpha'] = gr.Slider(0, 5, value=generate_params['penalty_alpha'], label='penalty_alpha')
|
||||||
with gr.Box():
|
|
||||||
gr.Markdown('Beam search (uses a lot of VRAM)')
|
gr.Markdown('Beam search (uses a lot of VRAM)')
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
@@ -236,6 +272,13 @@ def create_settings_menus(default_preset):
|
|||||||
shared.gradio['length_penalty'] = gr.Slider(-5, 5, value=generate_params['length_penalty'], label='length_penalty')
|
shared.gradio['length_penalty'] = gr.Slider(-5, 5, value=generate_params['length_penalty'], label='length_penalty')
|
||||||
shared.gradio['early_stopping'] = gr.Checkbox(value=generate_params['early_stopping'], label='early_stopping')
|
shared.gradio['early_stopping'] = gr.Checkbox(value=generate_params['early_stopping'], label='early_stopping')
|
||||||
|
|
||||||
|
with gr.Group():
|
||||||
|
with gr.Row():
|
||||||
|
shared.gradio['add_bos_token'] = gr.Checkbox(value=shared.settings['add_bos_token'], label='Add the bos_token to the beginning of prompts', info='Disabling this can make the replies more creative.')
|
||||||
|
shared.gradio['ban_eos_token'] = gr.Checkbox(value=shared.settings['ban_eos_token'], label='Ban the eos_token', info='This forces the model to never end the generation prematurely.')
|
||||||
|
shared.gradio['truncation_length'] = gr.Slider(value=shared.settings['truncation_length'], minimum=shared.settings['truncation_length_min'], maximum=shared.settings['truncation_length_max'], step=1, label='Truncate the prompt up to this length', info='The leftmost tokens are removed if the prompt exceeds this length. Most models require this to be at most 2048.')
|
||||||
|
shared.gradio['custom_stopping_strings'] = gr.Textbox(lines=1, value=shared.settings["custom_stopping_strings"] or None, label='Custom stopping strings', info='In addition to the defaults. Written between "" and separated by commas. For instance: "\\nYour Assistant:", "\\nThe assistant:"')
|
||||||
|
|
||||||
with gr.Accordion('Soft prompt', open=False):
|
with gr.Accordion('Soft prompt', open=False):
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
shared.gradio['softprompts_menu'] = gr.Dropdown(choices=available_softprompts, value='None', label='Soft prompt')
|
shared.gradio['softprompts_menu'] = gr.Dropdown(choices=available_softprompts, value='None', label='Soft prompt')
|
||||||
@@ -245,7 +288,7 @@ def create_settings_menus(default_preset):
|
|||||||
with gr.Row():
|
with gr.Row():
|
||||||
shared.gradio['upload_softprompt'] = gr.File(type='binary', file_types=['.zip'])
|
shared.gradio['upload_softprompt'] = gr.File(type='binary', file_types=['.zip'])
|
||||||
|
|
||||||
shared.gradio['preset_menu'].change(load_preset_values, [shared.gradio[k] for k in ['preset_menu', 'generate_state']], [shared.gradio[k] for k in ['generate_state', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']])
|
shared.gradio['preset_menu'].change(load_preset_values, [shared.gradio[k] for k in ['preset_menu', 'interface_state']], [shared.gradio[k] for k in ['interface_state', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']])
|
||||||
shared.gradio['softprompts_menu'].change(load_soft_prompt, shared.gradio['softprompts_menu'], shared.gradio['softprompts_menu'], show_progress=True)
|
shared.gradio['softprompts_menu'].change(load_soft_prompt, shared.gradio['softprompts_menu'], shared.gradio['softprompts_menu'], show_progress=True)
|
||||||
shared.gradio['upload_softprompt'].upload(upload_soft_prompt, shared.gradio['upload_softprompt'], shared.gradio['softprompts_menu'])
|
shared.gradio['upload_softprompt'].upload(upload_soft_prompt, shared.gradio['upload_softprompt'], shared.gradio['softprompts_menu'])
|
||||||
|
|
||||||
@@ -318,6 +361,21 @@ else:
|
|||||||
title = 'Text generation web UI'
|
title = 'Text generation web UI'
|
||||||
|
|
||||||
|
|
||||||
|
def list_interface_input_elements(chat=False):
|
||||||
|
elements = ['max_new_tokens', 'seed', 'temperature', 'top_p', 'top_k', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'no_repeat_ngram_size', 'min_length', 'do_sample', 'penalty_alpha', 'num_beams', 'length_penalty', 'early_stopping', 'add_bos_token', 'ban_eos_token', 'truncation_length', 'custom_stopping_strings']
|
||||||
|
if chat:
|
||||||
|
elements += ['name1', 'name2', 'greeting', 'context', 'end_of_turn', 'chat_prompt_size', 'chat_generation_attempts', 'stop_at_newline', 'mode']
|
||||||
|
return elements
|
||||||
|
|
||||||
|
|
||||||
|
def gather_interface_values(*args):
|
||||||
|
output = {}
|
||||||
|
for i, element in enumerate(shared.input_elements):
|
||||||
|
output[element] = args[i]
|
||||||
|
output['custom_stopping_strings'] = eval(f"[{output['custom_stopping_strings']}]")
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
def create_interface():
|
def create_interface():
|
||||||
gen_events = []
|
gen_events = []
|
||||||
if shared.args.extensions is not None and len(shared.args.extensions) > 0:
|
if shared.args.extensions is not None and len(shared.args.extensions) > 0:
|
||||||
@@ -325,27 +383,34 @@ def create_interface():
|
|||||||
|
|
||||||
with gr.Blocks(css=ui.css if not shared.is_chat() else ui.css + ui.chat_css, analytics_enabled=False, title=title) as shared.gradio['interface']:
|
with gr.Blocks(css=ui.css if not shared.is_chat() else ui.css + ui.chat_css, analytics_enabled=False, title=title) as shared.gradio['interface']:
|
||||||
if shared.is_chat():
|
if shared.is_chat():
|
||||||
|
|
||||||
|
shared.input_elements = list_interface_input_elements(chat=True)
|
||||||
|
shared.gradio['interface_state'] = gr.State({k: None for k in shared.input_elements})
|
||||||
shared.gradio['Chat input'] = gr.State()
|
shared.gradio['Chat input'] = gr.State()
|
||||||
|
|
||||||
with gr.Tab("Text generation", elem_id="main"):
|
with gr.Tab("Text generation", elem_id="main"):
|
||||||
shared.gradio['display'] = gr.HTML(value=chat_html_wrapper(shared.history['visible'], shared.settings['name1'], shared.settings['name2'], 'cai-chat'))
|
shared.gradio['display'] = gr.HTML(value=chat_html_wrapper(shared.history['visible'], shared.settings['name1'], shared.settings['name2'], 'cai-chat'))
|
||||||
shared.gradio['textbox'] = gr.Textbox(label='Input')
|
shared.gradio['textbox'] = gr.Textbox(label='Input')
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
shared.gradio['Generate'] = gr.Button('Generate')
|
shared.gradio['Generate'] = gr.Button('Generate', elem_id='Generate')
|
||||||
shared.gradio['Stop'] = gr.Button('Stop', elem_id="stop")
|
shared.gradio['Stop'] = gr.Button('Stop', elem_id="stop")
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
shared.gradio['Impersonate'] = gr.Button('Impersonate')
|
|
||||||
shared.gradio['Regenerate'] = gr.Button('Regenerate')
|
shared.gradio['Regenerate'] = gr.Button('Regenerate')
|
||||||
|
shared.gradio['Continue'] = gr.Button('Continue')
|
||||||
|
shared.gradio['Impersonate'] = gr.Button('Impersonate')
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
shared.gradio['Copy last reply'] = gr.Button('Copy last reply')
|
shared.gradio['Send dummy message'] = gr.Button('Send dummy message')
|
||||||
|
shared.gradio['Send dummy reply'] = gr.Button('Send dummy reply')
|
||||||
shared.gradio['Replace last reply'] = gr.Button('Replace last reply')
|
shared.gradio['Replace last reply'] = gr.Button('Replace last reply')
|
||||||
shared.gradio['Remove last'] = gr.Button('Remove last')
|
shared.gradio['Copy last reply'] = gr.Button('Copy last reply')
|
||||||
|
with gr.Row():
|
||||||
shared.gradio['Clear history'] = gr.Button('Clear history')
|
shared.gradio['Clear history'] = gr.Button('Clear history')
|
||||||
shared.gradio['Clear history-confirm'] = gr.Button('Confirm', variant="stop", visible=False)
|
shared.gradio['Clear history-confirm'] = gr.Button('Confirm', variant="stop", visible=False)
|
||||||
shared.gradio['Clear history-cancel'] = gr.Button('Cancel', visible=False)
|
shared.gradio['Clear history-cancel'] = gr.Button('Cancel', visible=False)
|
||||||
|
shared.gradio['Remove last'] = gr.Button('Remove last')
|
||||||
|
|
||||||
shared.gradio["Chat mode"] = gr.Radio(choices=["cai-chat", "chat", "instruct"], value="cai-chat", label="Mode")
|
shared.gradio["mode"] = gr.Radio(choices=["cai-chat", "chat", "instruct"], value="cai-chat", label="Mode")
|
||||||
shared.gradio["Instruction templates"] = gr.Dropdown(choices=get_available_instruction_templates(), label="Instruction template", value="None", visible=False)
|
shared.gradio["Instruction templates"] = gr.Dropdown(choices=get_available_instruction_templates(), label="Instruction template", value="None", visible=False, info="Change this according to the model/LoRA that you are using.")
|
||||||
|
|
||||||
with gr.Tab("Character", elem_id="chat-settings"):
|
with gr.Tab("Character", elem_id="chat-settings"):
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
@@ -392,66 +457,102 @@ def create_interface():
|
|||||||
with gr.Row():
|
with gr.Row():
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
shared.gradio['max_new_tokens'] = gr.Slider(minimum=shared.settings['max_new_tokens_min'], maximum=shared.settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=shared.settings['max_new_tokens'])
|
shared.gradio['max_new_tokens'] = gr.Slider(minimum=shared.settings['max_new_tokens_min'], maximum=shared.settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=shared.settings['max_new_tokens'])
|
||||||
shared.gradio['chat_prompt_size_slider'] = gr.Slider(minimum=shared.settings['chat_prompt_size_min'], maximum=shared.settings['chat_prompt_size_max'], step=1, label='Maximum prompt size in tokens', value=shared.settings['chat_prompt_size'])
|
shared.gradio['chat_prompt_size'] = gr.Slider(minimum=shared.settings['chat_prompt_size_min'], maximum=shared.settings['chat_prompt_size_max'], step=1, label='Maximum prompt size in tokens', value=shared.settings['chat_prompt_size'])
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
shared.gradio['chat_generation_attempts'] = gr.Slider(minimum=shared.settings['chat_generation_attempts_min'], maximum=shared.settings['chat_generation_attempts_max'], value=shared.settings['chat_generation_attempts'], step=1, label='Generation attempts (for longer replies)')
|
shared.gradio['chat_generation_attempts'] = gr.Slider(minimum=shared.settings['chat_generation_attempts_min'], maximum=shared.settings['chat_generation_attempts_max'], value=shared.settings['chat_generation_attempts'], step=1, label='Generation attempts (for longer replies)')
|
||||||
shared.gradio['stop_at_newline'] = gr.Checkbox(value=shared.settings['stop_at_newline'], label='Stop generating at new line character?')
|
shared.gradio['stop_at_newline'] = gr.Checkbox(value=shared.settings['stop_at_newline'], label='Stop generating at new line character')
|
||||||
|
|
||||||
create_settings_menus(default_preset)
|
create_settings_menus(default_preset)
|
||||||
|
|
||||||
shared.input_params = [shared.gradio[k] for k in ['Chat input', 'generate_state', 'name1', 'name2', 'context', 'Chat mode', 'end_of_turn']]
|
shared.input_params = [shared.gradio[k] for k in ['Chat input', 'interface_state']]
|
||||||
|
|
||||||
def set_chat_input(textbox):
|
|
||||||
return textbox, ""
|
|
||||||
|
|
||||||
gen_events.append(shared.gradio['Generate'].click(set_chat_input, shared.gradio['textbox'], [shared.gradio['Chat input'], shared.gradio['textbox']], show_progress=False))
|
|
||||||
gen_events.append(shared.gradio['Generate'].click(chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream))
|
|
||||||
gen_events.append(shared.gradio['textbox'].submit(set_chat_input, shared.gradio['textbox'], [shared.gradio['Chat input'], shared.gradio['textbox']], show_progress=False))
|
|
||||||
gen_events.append(shared.gradio['textbox'].submit(chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream))
|
|
||||||
gen_events.append(shared.gradio['Regenerate'].click(chat.regenerate_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream))
|
|
||||||
gen_events.append(shared.gradio['Impersonate'].click(chat.impersonate_wrapper, shared.input_params, shared.gradio['textbox'], show_progress=shared.args.no_stream))
|
|
||||||
shared.gradio['Stop'].click(stop_everything_event, [], [], queue=False, cancels=gen_events if shared.args.no_stream else None)
|
|
||||||
|
|
||||||
shared.gradio['Copy last reply'].click(chat.send_last_reply_to_input, [], shared.gradio['textbox'], show_progress=shared.args.no_stream)
|
|
||||||
shared.gradio['Replace last reply'].click(chat.replace_last_reply, [shared.gradio[k] for k in ['textbox', 'name1', 'name2', 'Chat mode']], shared.gradio['display'], show_progress=shared.args.no_stream)
|
|
||||||
|
|
||||||
# Clear history with confirmation
|
|
||||||
clear_arr = [shared.gradio[k] for k in ['Clear history-confirm', 'Clear history', 'Clear history-cancel']]
|
clear_arr = [shared.gradio[k] for k in ['Clear history-confirm', 'Clear history', 'Clear history-cancel']]
|
||||||
|
reload_inputs = [shared.gradio[k] for k in ['name1', 'name2', 'mode']]
|
||||||
|
|
||||||
|
gen_events.append(shared.gradio['Generate'].click(
|
||||||
|
gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
|
||||||
|
lambda x: (x, ''), shared.gradio['textbox'], [shared.gradio['Chat input'], shared.gradio['textbox']], show_progress=False).then(
|
||||||
|
chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream).then(
|
||||||
|
chat.save_history, shared.gradio['mode'], None, show_progress=False)
|
||||||
|
)
|
||||||
|
|
||||||
|
gen_events.append(shared.gradio['textbox'].submit(
|
||||||
|
gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
|
||||||
|
lambda x: (x, ''), shared.gradio['textbox'], [shared.gradio['Chat input'], shared.gradio['textbox']], show_progress=False).then(
|
||||||
|
chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream).then(
|
||||||
|
chat.save_history, shared.gradio['mode'], None, show_progress=False)
|
||||||
|
)
|
||||||
|
|
||||||
|
gen_events.append(shared.gradio['Regenerate'].click(
|
||||||
|
gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
|
||||||
|
chat.regenerate_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream).then(
|
||||||
|
chat.save_history, shared.gradio['mode'], None, show_progress=False)
|
||||||
|
)
|
||||||
|
|
||||||
|
gen_events.append(shared.gradio['Continue'].click(
|
||||||
|
gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
|
||||||
|
chat.continue_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream).then(
|
||||||
|
chat.save_history, shared.gradio['mode'], None, show_progress=False)
|
||||||
|
)
|
||||||
|
|
||||||
|
gen_events.append(shared.gradio['Impersonate'].click(
|
||||||
|
gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
|
||||||
|
chat.impersonate_wrapper, shared.input_params, shared.gradio['textbox'], show_progress=shared.args.no_stream)
|
||||||
|
)
|
||||||
|
|
||||||
|
shared.gradio['Replace last reply'].click(
|
||||||
|
chat.replace_last_reply, [shared.gradio[k] for k in ['textbox', 'name1', 'name2', 'mode']], shared.gradio['display'], show_progress=shared.args.no_stream).then(
|
||||||
|
lambda x: '', shared.gradio['textbox'], shared.gradio['textbox'], show_progress=False).then(
|
||||||
|
chat.save_history, shared.gradio['mode'], None, show_progress=False)
|
||||||
|
|
||||||
|
shared.gradio['Send dummy message'].click(
|
||||||
|
chat.send_dummy_message, [shared.gradio[k] for k in ['textbox', 'name1', 'name2', 'mode']], shared.gradio['display'], show_progress=shared.args.no_stream).then(
|
||||||
|
lambda x: '', shared.gradio['textbox'], shared.gradio['textbox'], show_progress=False).then(
|
||||||
|
chat.save_history, shared.gradio['mode'], None, show_progress=False)
|
||||||
|
|
||||||
|
shared.gradio['Send dummy reply'].click(
|
||||||
|
chat.send_dummy_reply, [shared.gradio[k] for k in ['textbox', 'name1', 'name2', 'mode']], shared.gradio['display'], show_progress=shared.args.no_stream).then(
|
||||||
|
lambda x: '', shared.gradio['textbox'], shared.gradio['textbox'], show_progress=False).then(
|
||||||
|
chat.save_history, shared.gradio['mode'], None, show_progress=False)
|
||||||
|
|
||||||
|
shared.gradio['Clear history-confirm'].click(
|
||||||
|
lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, clear_arr).then(
|
||||||
|
chat.clear_chat_log, [shared.gradio[k] for k in ['name1', 'name2', 'greeting', 'mode']], shared.gradio['display']).then(
|
||||||
|
chat.save_history, shared.gradio['mode'], None, show_progress=False)
|
||||||
|
|
||||||
|
shared.gradio['Stop'].click(
|
||||||
|
stop_everything_event, None, None, queue=False, cancels=gen_events if shared.args.no_stream else None).then(
|
||||||
|
chat.redraw_html, reload_inputs, shared.gradio['display'])
|
||||||
|
|
||||||
|
shared.gradio['mode'].change(
|
||||||
|
lambda x: gr.update(visible=x == 'instruct'), shared.gradio['mode'], shared.gradio['Instruction templates']).then(
|
||||||
|
lambda x: gr.update(interactive=x != 'instruct'), shared.gradio['mode'], shared.gradio['character_menu']).then(
|
||||||
|
chat.redraw_html, reload_inputs, shared.gradio['display'])
|
||||||
|
|
||||||
|
shared.gradio['Instruction templates'].change(
|
||||||
|
lambda character, name1, name2, mode: chat.load_character(character, name1, name2, mode), [shared.gradio[k] for k in ['Instruction templates', 'name1', 'name2', 'mode']], [shared.gradio[k] for k in ['name1', 'name2', 'character_picture', 'greeting', 'context', 'end_of_turn', 'display']]).then(
|
||||||
|
chat.redraw_html, reload_inputs, shared.gradio['display'])
|
||||||
|
|
||||||
|
shared.gradio['upload_chat_history'].upload(
|
||||||
|
chat.load_history, [shared.gradio[k] for k in ['upload_chat_history', 'name1', 'name2']], None).then(
|
||||||
|
chat.redraw_html, reload_inputs, shared.gradio['display'])
|
||||||
|
|
||||||
|
shared.gradio['Copy last reply'].click(chat.send_last_reply_to_input, None, shared.gradio['textbox'], show_progress=shared.args.no_stream)
|
||||||
shared.gradio['Clear history'].click(lambda: [gr.update(visible=True), gr.update(visible=False), gr.update(visible=True)], None, clear_arr)
|
shared.gradio['Clear history'].click(lambda: [gr.update(visible=True), gr.update(visible=False), gr.update(visible=True)], None, clear_arr)
|
||||||
shared.gradio['Clear history-confirm'].click(lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, clear_arr)
|
|
||||||
shared.gradio['Clear history-confirm'].click(chat.clear_chat_log, [shared.gradio[k] for k in ['name1', 'name2', 'greeting', 'Chat mode']], shared.gradio['display'])
|
|
||||||
shared.gradio['Clear history-cancel'].click(lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, clear_arr)
|
shared.gradio['Clear history-cancel'].click(lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, clear_arr)
|
||||||
shared.gradio['Chat mode'].change(lambda x: gr.update(visible=x == 'instruct'), shared.gradio['Chat mode'], shared.gradio['Instruction templates'])
|
shared.gradio['Remove last'].click(chat.remove_last_message, [shared.gradio[k] for k in ['name1', 'name2', 'mode']], [shared.gradio['display'], shared.gradio['textbox']], show_progress=False)
|
||||||
|
shared.gradio['download_button'].click(lambda x: chat.save_history(x, timestamp=True), shared.gradio['mode'], shared.gradio['download'])
|
||||||
shared.gradio['Remove last'].click(chat.remove_last_message, [shared.gradio[k] for k in ['name1', 'name2', 'Chat mode']], [shared.gradio['display'], shared.gradio['textbox']], show_progress=False)
|
|
||||||
shared.gradio['download_button'].click(chat.save_history, inputs=[], outputs=[shared.gradio['download']])
|
|
||||||
shared.gradio['Upload character'].click(chat.upload_character, [shared.gradio['upload_json'], shared.gradio['upload_img_bot']], [shared.gradio['character_menu']])
|
shared.gradio['Upload character'].click(chat.upload_character, [shared.gradio['upload_json'], shared.gradio['upload_img_bot']], [shared.gradio['character_menu']])
|
||||||
|
shared.gradio['character_menu'].change(chat.load_character, [shared.gradio[k] for k in ['character_menu', 'name1', 'name2', 'mode']], [shared.gradio[k] for k in ['name1', 'name2', 'character_picture', 'greeting', 'context', 'end_of_turn', 'display']])
|
||||||
# Clearing stuff and saving the history
|
|
||||||
for i in ['Generate', 'Regenerate', 'Replace last reply']:
|
|
||||||
shared.gradio[i].click(lambda x: '', shared.gradio['textbox'], shared.gradio['textbox'], show_progress=False)
|
|
||||||
shared.gradio[i].click(lambda: chat.save_history(timestamp=False), [], [], show_progress=False)
|
|
||||||
shared.gradio['Clear history-confirm'].click(lambda: chat.save_history(timestamp=False), [], [], show_progress=False)
|
|
||||||
shared.gradio['textbox'].submit(lambda x: '', shared.gradio['textbox'], shared.gradio['textbox'], show_progress=False)
|
|
||||||
shared.gradio['textbox'].submit(lambda: chat.save_history(timestamp=False), [], [], show_progress=False)
|
|
||||||
|
|
||||||
shared.gradio['character_menu'].change(chat.load_character, [shared.gradio[k] for k in ['character_menu', 'name1', 'name2', 'Chat mode']], [shared.gradio[k] for k in ['name1', 'name2', 'character_picture', 'greeting', 'context', 'end_of_turn', 'display']])
|
|
||||||
shared.gradio['Instruction templates'].change(lambda character, name1, name2, mode: chat.load_character(character, name1, name2, mode), [shared.gradio[k] for k in ['Instruction templates', 'name1', 'name2', 'Chat mode']], [shared.gradio[k] for k in ['name1', 'name2', 'character_picture', 'greeting', 'context', 'end_of_turn', 'display']])
|
|
||||||
shared.gradio['upload_chat_history'].upload(chat.load_history, [shared.gradio[k] for k in ['upload_chat_history', 'name1', 'name2']], [])
|
|
||||||
shared.gradio['upload_img_tavern'].upload(chat.upload_tavern_character, [shared.gradio['upload_img_tavern'], shared.gradio['name1'], shared.gradio['name2']], [shared.gradio['character_menu']])
|
shared.gradio['upload_img_tavern'].upload(chat.upload_tavern_character, [shared.gradio['upload_img_tavern'], shared.gradio['name1'], shared.gradio['name2']], [shared.gradio['character_menu']])
|
||||||
shared.gradio['your_picture'].change(chat.upload_your_profile_picture, [shared.gradio[k] for k in ['your_picture', 'name1', 'name2', 'Chat mode']], shared.gradio['display'])
|
shared.gradio['your_picture'].change(chat.upload_your_profile_picture, [shared.gradio[k] for k in ['your_picture', 'name1', 'name2', 'mode']], shared.gradio['display'])
|
||||||
|
|
||||||
reload_inputs = [shared.gradio[k] for k in ['name1', 'name2', 'Chat mode']]
|
|
||||||
shared.gradio['upload_chat_history'].upload(chat.redraw_html, reload_inputs, [shared.gradio['display']])
|
|
||||||
shared.gradio['Stop'].click(chat.redraw_html, reload_inputs, [shared.gradio['display']])
|
|
||||||
shared.gradio['Instruction templates'].change(chat.redraw_html, reload_inputs, [shared.gradio['display']])
|
|
||||||
shared.gradio['Chat mode'].change(chat.redraw_html, reload_inputs, [shared.gradio['display']])
|
|
||||||
|
|
||||||
shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js+ui.chat_js}}}")
|
shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js+ui.chat_js}}}")
|
||||||
shared.gradio['interface'].load(lambda: chat.load_default_history(shared.settings['name1'], shared.settings['name2']), None, None)
|
shared.gradio['interface'].load(chat.load_default_history, [shared.gradio[k] for k in ['name1', 'name2']], None)
|
||||||
shared.gradio['interface'].load(chat.redraw_html, reload_inputs, [shared.gradio['display']], show_progress=True)
|
shared.gradio['interface'].load(chat.redraw_html, reload_inputs, shared.gradio['display'], show_progress=True)
|
||||||
|
|
||||||
elif shared.args.notebook:
|
elif shared.args.notebook:
|
||||||
|
shared.input_elements = list_interface_input_elements(chat=False)
|
||||||
|
shared.gradio['interface_state'] = gr.State({k: None for k in shared.input_elements})
|
||||||
with gr.Tab("Text generation", elem_id="main"):
|
with gr.Tab("Text generation", elem_id="main"):
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
with gr.Column(scale=4):
|
with gr.Column(scale=4):
|
||||||
@@ -479,14 +580,27 @@ def create_interface():
|
|||||||
with gr.Tab("Parameters", elem_id="parameters"):
|
with gr.Tab("Parameters", elem_id="parameters"):
|
||||||
create_settings_menus(default_preset)
|
create_settings_menus(default_preset)
|
||||||
|
|
||||||
shared.input_params = [shared.gradio[k] for k in ['textbox', 'generate_state']]
|
shared.input_params = [shared.gradio[k] for k in ['textbox', 'interface_state']]
|
||||||
output_params = [shared.gradio[k] for k in ['textbox', 'markdown', 'html']]
|
output_params = [shared.gradio[k] for k in ['textbox', 'markdown', 'html']]
|
||||||
gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream))
|
|
||||||
gen_events.append(shared.gradio['textbox'].submit(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream))
|
gen_events.append(shared.gradio['Generate'].click(
|
||||||
shared.gradio['Stop'].click(stop_everything_event, [], [], queue=False, cancels=gen_events if shared.args.no_stream else None)
|
gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
|
||||||
|
generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream)#.then(
|
||||||
|
#None, None, None, _js="() => {element = document.getElementsByTagName('textarea')[0]; element.scrollTop = element.scrollHeight}")
|
||||||
|
)
|
||||||
|
|
||||||
|
gen_events.append(shared.gradio['textbox'].submit(
|
||||||
|
gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
|
||||||
|
generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream)#.then(
|
||||||
|
#None, None, None, _js="() => {element = document.getElementsByTagName('textarea')[0]; element.scrollTop = element.scrollHeight}")
|
||||||
|
)
|
||||||
|
|
||||||
|
shared.gradio['Stop'].click(stop_everything_event, None, None, queue=False, cancels=gen_events if shared.args.no_stream else None)
|
||||||
shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js}}}")
|
shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js}}}")
|
||||||
|
|
||||||
else:
|
else:
|
||||||
|
shared.input_elements = list_interface_input_elements(chat=False)
|
||||||
|
shared.gradio['interface_state'] = gr.State({k: None for k in shared.input_elements})
|
||||||
with gr.Tab("Text generation", elem_id="main"):
|
with gr.Tab("Text generation", elem_id="main"):
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
@@ -512,12 +626,28 @@ def create_interface():
|
|||||||
with gr.Tab("Parameters", elem_id="parameters"):
|
with gr.Tab("Parameters", elem_id="parameters"):
|
||||||
create_settings_menus(default_preset)
|
create_settings_menus(default_preset)
|
||||||
|
|
||||||
shared.input_params = [shared.gradio[k] for k in ['textbox', 'generate_state']]
|
shared.input_params = [shared.gradio[k] for k in ['textbox', 'interface_state']]
|
||||||
output_params = [shared.gradio[k] for k in ['output_textbox', 'markdown', 'html']]
|
output_params = [shared.gradio[k] for k in ['output_textbox', 'markdown', 'html']]
|
||||||
gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream))
|
|
||||||
gen_events.append(shared.gradio['textbox'].submit(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream))
|
gen_events.append(shared.gradio['Generate'].click(
|
||||||
gen_events.append(shared.gradio['Continue'].click(generate_reply, [shared.gradio['output_textbox']] + shared.input_params[1:], output_params, show_progress=shared.args.no_stream))
|
gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
|
||||||
shared.gradio['Stop'].click(stop_everything_event, [], [], queue=False, cancels=gen_events if shared.args.no_stream else None)
|
generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream)#.then(
|
||||||
|
#None, None, None, _js="() => {element = document.getElementsByTagName('textarea')[1]; element.scrollTop = element.scrollHeight}")
|
||||||
|
)
|
||||||
|
|
||||||
|
gen_events.append(shared.gradio['textbox'].submit(
|
||||||
|
gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
|
||||||
|
generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream)#.then(
|
||||||
|
#None, None, None, _js="() => {element = document.getElementsByTagName('textarea')[1]; element.scrollTop = element.scrollHeight}")
|
||||||
|
)
|
||||||
|
|
||||||
|
gen_events.append(shared.gradio['Continue'].click(
|
||||||
|
gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
|
||||||
|
generate_reply, [shared.gradio['output_textbox']] + shared.input_params[1:], output_params, show_progress=shared.args.no_stream)#.then(
|
||||||
|
#None, None, None, _js="() => {element = document.getElementsByTagName('textarea')[1]; element.scrollTop = element.scrollHeight}")
|
||||||
|
)
|
||||||
|
|
||||||
|
shared.gradio['Stop'].click(stop_everything_event, None, None, queue=False, cancels=gen_events if shared.args.no_stream else None)
|
||||||
shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js}}}")
|
shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js}}}")
|
||||||
|
|
||||||
with gr.Tab("Model", elem_id="model-tab"):
|
with gr.Tab("Model", elem_id="model-tab"):
|
||||||
@@ -541,26 +671,16 @@ def create_interface():
|
|||||||
shared.gradio['interface_modes_menu'] = gr.Dropdown(choices=modes, value=current_mode, label="Mode")
|
shared.gradio['interface_modes_menu'] = gr.Dropdown(choices=modes, value=current_mode, label="Mode")
|
||||||
shared.gradio['extensions_menu'] = gr.CheckboxGroup(choices=get_available_extensions(), value=shared.args.extensions, label="Available extensions")
|
shared.gradio['extensions_menu'] = gr.CheckboxGroup(choices=get_available_extensions(), value=shared.args.extensions, label="Available extensions")
|
||||||
shared.gradio['bool_menu'] = gr.CheckboxGroup(choices=bool_list, value=bool_active, label="Boolean command-line flags")
|
shared.gradio['bool_menu'] = gr.CheckboxGroup(choices=bool_list, value=bool_active, label="Boolean command-line flags")
|
||||||
shared.gradio['reset_interface'] = gr.Button("Apply and restart the interface", type="primary")
|
shared.gradio['reset_interface'] = gr.Button("Apply and restart the interface")
|
||||||
|
|
||||||
shared.gradio['reset_interface'].click(set_interface_arguments, [shared.gradio[k] for k in ['interface_modes_menu', 'extensions_menu', 'bool_menu']], None)
|
# Reset interface event
|
||||||
shared.gradio['reset_interface'].click(lambda: None, None, None, _js='() => {document.body.innerHTML=\'<h1 style="font-family:monospace;margin-top:20%;color:lightgray;text-align:center;">Reloading...</h1>\'; setTimeout(function(){location.reload()},2500); return []}')
|
shared.gradio['reset_interface'].click(
|
||||||
|
set_interface_arguments, [shared.gradio[k] for k in ['interface_modes_menu', 'extensions_menu', 'bool_menu']], None).then(
|
||||||
|
lambda: None, None, None, _js='() => {document.body.innerHTML=\'<h1 style="font-family:monospace;margin-top:20%;color:lightgray;text-align:center;">Reloading...</h1>\'; setTimeout(function(){location.reload()},2500); return []}')
|
||||||
|
|
||||||
if shared.args.extensions is not None:
|
if shared.args.extensions is not None:
|
||||||
extensions_module.create_extensions_block()
|
extensions_module.create_extensions_block()
|
||||||
|
|
||||||
def change_dict_value(d, key, value):
|
|
||||||
d[key] = value
|
|
||||||
return d
|
|
||||||
|
|
||||||
for k in ['do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'max_new_tokens', 'seed', 'stop_at_newline', 'chat_prompt_size_slider', 'chat_generation_attempts']:
|
|
||||||
if k not in shared.gradio:
|
|
||||||
continue
|
|
||||||
if type(shared.gradio[k]) in [gr.Checkbox, gr.Number]:
|
|
||||||
shared.gradio[k].change(lambda state, value, copy=k: change_dict_value(state, copy, value), inputs=[shared.gradio['generate_state'], shared.gradio[k]], outputs=shared.gradio['generate_state'])
|
|
||||||
else:
|
|
||||||
shared.gradio[k].release(lambda state, value, copy=k: change_dict_value(state, copy, value), inputs=[shared.gradio['generate_state'], shared.gradio[k]], outputs=shared.gradio['generate_state'])
|
|
||||||
|
|
||||||
if not shared.is_chat():
|
if not shared.is_chat():
|
||||||
api.create_apis()
|
api.create_apis()
|
||||||
|
|
||||||
|
|||||||
@@ -7,7 +7,14 @@
|
|||||||
"name2": "Assistant",
|
"name2": "Assistant",
|
||||||
"context": "This is a conversation with your Assistant. The Assistant is very helpful and is eager to chat with you and answer your questions.",
|
"context": "This is a conversation with your Assistant. The Assistant is very helpful and is eager to chat with you and answer your questions.",
|
||||||
"greeting": "Hello there!",
|
"greeting": "Hello there!",
|
||||||
|
"end_of_turn": "",
|
||||||
|
"custom_stopping_strings": "",
|
||||||
"stop_at_newline": false,
|
"stop_at_newline": false,
|
||||||
|
"add_bos_token": true,
|
||||||
|
"ban_eos_token": true,
|
||||||
|
"truncation_length": 2048,
|
||||||
|
"truncation_length_min": 0,
|
||||||
|
"truncation_length_max": 4096,
|
||||||
"chat_prompt_size": 2048,
|
"chat_prompt_size": 2048,
|
||||||
"chat_prompt_size_min": 0,
|
"chat_prompt_size_min": 0,
|
||||||
"chat_prompt_size_max": 2048,
|
"chat_prompt_size_max": 2048,
|
||||||
@@ -19,7 +26,8 @@
|
|||||||
"gallery"
|
"gallery"
|
||||||
],
|
],
|
||||||
"presets": {
|
"presets": {
|
||||||
"default": "NovelAI-Sphinx Moth",
|
"default": "Default",
|
||||||
|
".*(alpaca|llama)": "LLaMA-Precise",
|
||||||
".*pygmalion": "NovelAI-Storywriter",
|
".*pygmalion": "NovelAI-Storywriter",
|
||||||
".*RWKV": "Naive"
|
".*RWKV": "Naive"
|
||||||
},
|
},
|
||||||
|
|||||||
Reference in New Issue
Block a user