2 Commits

Author SHA1 Message Date
oobabooga
f57ebb6c42 Revert unintended changes 2023-04-07 00:13:14 -03:00
oobabooga
01cacfc14f Reformat everything 2023-04-07 00:08:46 -03:00
39 changed files with 646 additions and 1625 deletions

View File

@@ -1,6 +1,7 @@
.env .env
Dockerfile Dockerfile
/characters /characters
/extensions
/loras /loras
/models /models
/presets /presets

View File

@@ -26,11 +26,12 @@ LABEL maintainer="Your Name <your.email@example.com>"
LABEL description="Docker image for GPTQ-for-LLaMa and Text Generation WebUI" LABEL description="Docker image for GPTQ-for-LLaMa and Text Generation WebUI"
RUN apt-get update && \ RUN apt-get update && \
apt-get install --no-install-recommends -y git python3 python3-pip make g++ && \ apt-get install --no-install-recommends -y git python3 python3-pip && \
rm -rf /var/lib/apt/lists/* rm -rf /var/lib/apt/lists/*
RUN --mount=type=cache,target=/root/.cache/pip pip3 install virtualenv RUN --mount=type=cache,target=/root/.cache/pip pip3 install virtualenv
RUN mkdir /app
COPY . /app/
WORKDIR /app WORKDIR /app
@@ -40,29 +41,21 @@ RUN test -n "${WEBUI_VERSION}" && git reset --hard ${WEBUI_VERSION} || echo "Usi
RUN virtualenv /app/venv RUN virtualenv /app/venv
RUN . /app/venv/bin/activate && \ RUN . /app/venv/bin/activate && \
pip3 install --upgrade pip setuptools && \ pip3 install --upgrade pip setuptools && \
pip3 install torch torchvision torchaudio pip3 install torch torchvision torchaudio && \
pip3 install -r requirements.txt
COPY --from=builder /build /app/repositories/GPTQ-for-LLaMa COPY --from=builder /build /app/repositories/GPTQ-for-LLaMa
RUN . /app/venv/bin/activate && \ RUN . /app/venv/bin/activate && \
pip3 install /app/repositories/GPTQ-for-LLaMa/*.whl pip3 install /app/repositories/GPTQ-for-LLaMa/*.whl
COPY extensions/api/requirements.txt /app/extensions/api/requirements.txt ENV CLI_ARGS=""
COPY extensions/elevenlabs_tts/requirements.txt /app/extensions/elevenlabs_tts/requirements.txt
COPY extensions/google_translate/requirements.txt /app/extensions/google_translate/requirements.txt
COPY extensions/silero_tts/requirements.txt /app/extensions/silero_tts/requirements.txt
COPY extensions/whisper_stt/requirements.txt /app/extensions/whisper_stt/requirements.txt
RUN --mount=type=cache,target=/root/.cache/pip . /app/venv/bin/activate && cd extensions/api && pip3 install -r requirements.txt RUN --mount=type=cache,target=/root/.cache/pip . /app/venv/bin/activate && cd extensions/api && pip3 install -r requirements.txt
RUN --mount=type=cache,target=/root/.cache/pip . /app/venv/bin/activate && cd extensions/elevenlabs_tts && pip3 install -r requirements.txt RUN --mount=type=cache,target=/root/.cache/pip . /app/venv/bin/activate && cd extensions/elevenlabs_tts && pip3 install -r requirements.txt
RUN --mount=type=cache,target=/root/.cache/pip . /app/venv/bin/activate && cd extensions/google_translate && pip3 install -r requirements.txt RUN --mount=type=cache,target=/root/.cache/pip . /app/venv/bin/activate && cd extensions/google_translate && pip3 install -r requirements.txt
RUN --mount=type=cache,target=/root/.cache/pip . /app/venv/bin/activate && cd extensions/silero_tts && pip3 install -r requirements.txt RUN --mount=type=cache,target=/root/.cache/pip . /app/venv/bin/activate && cd extensions/silero_tts && pip3 install -r requirements.txt
RUN --mount=type=cache,target=/root/.cache/pip . /app/venv/bin/activate && cd extensions/whisper_stt && pip3 install -r requirements.txt RUN --mount=type=cache,target=/root/.cache/pip . /app/venv/bin/activate && cd extensions/whisper_stt && pip3 install -r requirements.txt
COPY requirements.txt /app/requirements.txt
RUN . /app/venv/bin/activate && \
pip3 install -r requirements.txt
RUN cp /app/venv/lib/python3.10/site-packages/bitsandbytes/libbitsandbytes_cuda118.so /app/venv/lib/python3.10/site-packages/bitsandbytes/libbitsandbytes_cpu.so RUN cp /app/venv/lib/python3.10/site-packages/bitsandbytes/libbitsandbytes_cuda118.so /app/venv/lib/python3.10/site-packages/bitsandbytes/libbitsandbytes_cpu.so
COPY . /app/
ENV CLI_ARGS=""
CMD . /app/venv/bin/activate && python3 server.py ${CLI_ARGS} CMD . /app/venv/bin/activate && python3 server.py ${CLI_ARGS}

127
README.md
View File

@@ -1,9 +1,11 @@
# Text generation web UI # Text generation web UI
A gradio web UI for running Large Language Models like LLaMA, llama.cpp, GPT-J, Pythia, OPT, and GALACTICA. A gradio web UI for running Large Language Models like LLaMA, llama.cpp, GPT-J, OPT, and GALACTICA.
Its goal is to become the [AUTOMATIC1111/stable-diffusion-webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui) of text generation. Its goal is to become the [AUTOMATIC1111/stable-diffusion-webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui) of text generation.
[[Try it on Google Colab]](https://colab.research.google.com/github/oobabooga/AI-Notebooks/blob/main/Colab-TextGen-GPU.ipynb)
|![Image1](https://github.com/oobabooga/screenshots/raw/main/qa.png) | ![Image2](https://github.com/oobabooga/screenshots/raw/main/cai3.png) | |![Image1](https://github.com/oobabooga/screenshots/raw/main/qa.png) | ![Image2](https://github.com/oobabooga/screenshots/raw/main/cai3.png) |
|:---:|:---:| |:---:|:---:|
|![Image3](https://github.com/oobabooga/screenshots/raw/main/gpt4chan.png) | ![Image4](https://github.com/oobabooga/screenshots/raw/main/galactica.png) | |![Image3](https://github.com/oobabooga/screenshots/raw/main/gpt4chan.png) | ![Image4](https://github.com/oobabooga/screenshots/raw/main/galactica.png) |
@@ -13,7 +15,7 @@ Its goal is to become the [AUTOMATIC1111/stable-diffusion-webui](https://github.
* Dropdown menu for switching between models * Dropdown menu for switching between models
* Notebook mode that resembles OpenAI's playground * Notebook mode that resembles OpenAI's playground
* Chat mode for conversation and role playing * Chat mode for conversation and role playing
* Instruct mode compatible with Alpaca, Vicuna, and Open Assistant formats **\*NEW!\*** * Instruct mode compatible with Alpaca and Open Assistant formats **\*NEW!\***
* Nice HTML output for GPT-4chan * Nice HTML output for GPT-4chan
* Markdown output for [GALACTICA](https://github.com/paperswithcode/galai), including LaTeX rendering * Markdown output for [GALACTICA](https://github.com/paperswithcode/galai), including LaTeX rendering
* [Custom chat characters](https://github.com/oobabooga/text-generation-webui/wiki/Custom-chat-characters) * [Custom chat characters](https://github.com/oobabooga/text-generation-webui/wiki/Custom-chat-characters)
@@ -32,6 +34,7 @@ Its goal is to become the [AUTOMATIC1111/stable-diffusion-webui](https://github.
* [LoRA (loading and training)](https://github.com/oobabooga/text-generation-webui/wiki/Using-LoRAs) * [LoRA (loading and training)](https://github.com/oobabooga/text-generation-webui/wiki/Using-LoRAs)
* Softprompts * Softprompts
* [Extensions](https://github.com/oobabooga/text-generation-webui/wiki/Extensions) * [Extensions](https://github.com/oobabooga/text-generation-webui/wiki/Extensions)
* [Google Colab](https://github.com/oobabooga/text-generation-webui/wiki/Running-on-Colab)
## Installation ## Installation
@@ -70,15 +73,9 @@ On Linux or WSL, it can be automatically installed with these two commands:
curl -sL "https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh" > "Miniconda3.sh" curl -sL "https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh" > "Miniconda3.sh"
bash Miniconda3.sh bash Miniconda3.sh
``` ```
Source: https://educe-ubc.github.io/conda.html Source: https://educe-ubc.github.io/conda.html
#### 0.1 (Ubuntu/WSL) Install build tools
```
sudo apt install build-essential
```
#### 1. Create a new conda environment #### 1. Create a new conda environment
``` ```
@@ -122,7 +119,7 @@ As an alternative to the recommended WSL method, you can install the web UI nati
``` ```
cp .env.example .env cp .env.example .env
docker compose up --build docker-compose up --build
``` ```
Make sure to edit `.env.example` and set the appropriate CUDA version for your GPU. Make sure to edit `.env.example` and set the appropriate CUDA version for your GPU.
@@ -194,84 +191,82 @@ Optionally, you can use the following command-line flags:
#### Basic settings #### Basic settings
| Flag | Description | | Flag | Description |
|--------------------------------------------|-------------| |------------------|-------------|
| `-h`, `--help` | Show this help message and exit. | | `-h`, `--help` | show this help message and exit |
| `--notebook` | Launch the web UI in notebook mode, where the output is written to the same text box as the input. | | `--notebook` | Launch the web UI in notebook mode, where the output is written to the same text box as the input. |
| `--chat` | Launch the web UI in chat mode. | | `--chat` | Launch the web UI in chat mode.|
| `--model MODEL` | Name of the model to load by default. | | `--model MODEL` | Name of the model to load by default. |
| `--lora LORA` | Name of the LoRA to apply to the model by default. | | `--lora LORA` | Name of the LoRA to apply to the model by default. |
| `--model-dir MODEL_DIR` | Path to directory with all the models. | | `--model-dir MODEL_DIR` | Path to directory with all the models |
| `--lora-dir LORA_DIR` | Path to directory with all the loras. | | `--lora-dir LORA_DIR` | Path to directory with all the loras |
| `--no-stream` | Don't stream the text output in real time. | | `--no-stream` | Don't stream the text output in real time. |
| `--settings SETTINGS_FILE` | Load the default interface settings from this json file. See `settings-template.json` for an example. If you create a file called `settings.json`, this file will be loaded by default without the need to use the `--settings` flag. | | `--settings SETTINGS_FILE` | Load the default interface settings from this json file. See `settings-template.json` for an example. If you create a file called `settings.json`, this file will be loaded by default without the need to use the `--settings` flag.|
| `--extensions EXTENSIONS [EXTENSIONS ...]` | The list of extensions to load. If you want to load more than one extension, write the names separated by spaces. | | `--extensions EXTENSIONS [EXTENSIONS ...]` | The list of extensions to load. If you want to load more than one extension, write the names separated by spaces. |
| `--verbose` | Print the prompts to the terminal. | | `--verbose` | Print the prompts to the terminal. |
#### Accelerate/transformers #### Accelerate/transformers
| Flag | Description | | Flag | Description |
|---------------------------------------------|-------------| |------------------|-------------|
| `--cpu` | Use the CPU to generate text. Warning: Training on CPU is extremely slow.| | `--cpu` | Use the CPU to generate text.|
| `--auto-devices` | Automatically split the model across the available GPU(s) and CPU. | | `--auto-devices` | Automatically split the model across the available GPU(s) and CPU.|
| `--gpu-memory GPU_MEMORY [GPU_MEMORY ...]` | Maxmimum GPU memory in GiB to be allocated per GPU. Example: `--gpu-memory 10` for a single GPU, `--gpu-memory 10 5` for two GPUs. You can also set values in MiB like `--gpu-memory 3500MiB`. | | `--gpu-memory GPU_MEMORY [GPU_MEMORY ...]` | Maxmimum GPU memory in GiB to be allocated per GPU. Example: `--gpu-memory 10` for a single GPU, `--gpu-memory 10 5` for two GPUs. You can also set values in MiB like `--gpu-memory 3500MiB`. |
| `--cpu-memory CPU_MEMORY` | Maximum CPU memory in GiB to allocate for offloaded weights. Same as above.| | `--cpu-memory CPU_MEMORY` | Maximum CPU memory in GiB to allocate for offloaded weights. Same as above.|
| `--disk` | If the model is too large for your GPU(s) and CPU combined, send the remaining layers to the disk. | | `--disk` | If the model is too large for your GPU(s) and CPU combined, send the remaining layers to the disk. |
| `--disk-cache-dir DISK_CACHE_DIR` | Directory to save the disk cache to. Defaults to `cache/`. | | `--disk-cache-dir DISK_CACHE_DIR` | Directory to save the disk cache to. Defaults to `cache/`. |
| `--load-in-8bit` | Load the model with 8-bit precision.| | `--load-in-8bit` | Load the model with 8-bit precision.|
| `--bf16` | Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU. | | `--bf16` | Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU. |
| `--no-cache` | Set `use_cache` to False while generating text. This reduces the VRAM usage a bit with a performance cost. | | `--no-cache` | Set `use_cache` to False while generating text. This reduces the VRAM usage a bit with a performance cost. |
| `--xformers` | Use xformer's memory efficient attention. This should increase your tokens/s. |
| `--sdp-attention` | Use torch 2.0's sdp attention. |
#### llama.cpp #### llama.cpp
| Flag | Description | | Flag | Description |
|-------------|-------------| |------------------|-------------|
| `--threads` | Number of threads to use in llama.cpp. | | `--threads` | Number of threads to use in llama.cpp. |
#### GPTQ #### GPTQ
| Flag | Description | | Flag | Description |
|---------------------------|-------------| |------------------|-------------|
| `--wbits WBITS` | GPTQ: Load a pre-quantized model with specified precision in bits. 2, 3, 4 and 8 are supported. | | `--wbits WBITS` | GPTQ: Load a pre-quantized model with specified precision in bits. 2, 3, 4 and 8 are supported. |
| `--model_type MODEL_TYPE` | GPTQ: Model type of pre-quantized model. Currently LLaMA, OPT, and GPT-J are supported. | | `--model_type MODEL_TYPE` | GPTQ: Model type of pre-quantized model. Currently LLaMA, OPT, and GPT-J are supported. |
| `--groupsize GROUPSIZE` | GPTQ: Group size. | | `--groupsize GROUPSIZE` | GPTQ: Group size. |
| `--pre_layer PRE_LAYER` | GPTQ: The number of layers to allocate to the GPU. Setting this parameter enables CPU offloading for 4-bit models. | | `--pre_layer PRE_LAYER` | GPTQ: The number of layers to allocate to the GPU. Setting this parameter enables CPU offloading for 4-bit models. |
#### FlexGen #### FlexGen
| Flag | Description | | Flag | Description |
|------------------|-------------| |------------------|-------------|
| `--flexgen` | Enable the use of FlexGen offloading. | | `--flexgen` | Enable the use of FlexGen offloading. |
| `--percent PERCENT [PERCENT ...]` | FlexGen: allocation percentages. Must be 6 numbers separated by spaces (default: 0, 100, 100, 0, 100, 0). | | `--percent PERCENT [PERCENT ...]` | FlexGen: allocation percentages. Must be 6 numbers separated by spaces (default: 0, 100, 100, 0, 100, 0). |
| `--compress-weight` | FlexGen: Whether to compress weight (default: False).| | `--compress-weight` | FlexGen: Whether to compress weight (default: False).|
| `--pin-weight [PIN_WEIGHT]` | FlexGen: whether to pin weights (setting this to False reduces CPU memory by 20%). | | `--pin-weight [PIN_WEIGHT]` | FlexGen: whether to pin weights (setting this to False reduces CPU memory by 20%). |
#### DeepSpeed #### DeepSpeed
| Flag | Description | | Flag | Description |
|---------------------------------------|-------------| |------------------|-------------|
| `--deepspeed` | Enable the use of DeepSpeed ZeRO-3 for inference via the Transformers integration. | | `--deepspeed` | Enable the use of DeepSpeed ZeRO-3 for inference via the Transformers integration. |
| `--nvme-offload-dir NVME_OFFLOAD_DIR` | DeepSpeed: Directory to use for ZeRO-3 NVME offloading. | | `--nvme-offload-dir NVME_OFFLOAD_DIR` | DeepSpeed: Directory to use for ZeRO-3 NVME offloading. |
| `--local_rank LOCAL_RANK` | DeepSpeed: Optional argument for distributed setups. | | `--local_rank LOCAL_RANK` | DeepSpeed: Optional argument for distributed setups. |
#### RWKV #### RWKV
| Flag | Description | | Flag | Description |
|---------------------------------|-------------| |------------------|-------------|
| `--rwkv-strategy RWKV_STRATEGY` | RWKV: The strategy to use while loading the model. Examples: "cpu fp32", "cuda fp16", "cuda fp16i8". | | `--rwkv-strategy RWKV_STRATEGY` | RWKV: The strategy to use while loading the model. Examples: "cpu fp32", "cuda fp16", "cuda fp16i8". |
| `--rwkv-cuda-on` | RWKV: Compile the CUDA kernel for better performance. | | `--rwkv-cuda-on` | RWKV: Compile the CUDA kernel for better performance. |
#### Gradio #### Gradio
| Flag | Description | | Flag | Description |
|---------------------------------------|-------------| |------------------|-------------|
| `--listen` | Make the web UI reachable from your local network. | | `--listen` | Make the web UI reachable from your local network. |
| `--listen-port LISTEN_PORT` | The listening port that the server will use. | | `--listen-port LISTEN_PORT` | The listening port that the server will use. |
| `--share` | Create a public URL. This is useful for running the web UI on Google Colab or similar. | | `--share` | Create a public URL. This is useful for running the web UI on Google Colab or similar. |
| `--auto-launch` | Open the web UI in the default browser upon launch. | | `--auto-launch` | Open the web UI in the default browser upon launch. |
| `--gradio-auth-path GRADIO_AUTH_PATH` | Set the gradio authentication file path. The file should contain one or more user:password pairs in this format: "u1:p1,u2:p2,u3:p3" | | `--gradio-auth-path GRADIO_AUTH_PATH` | Set the gradio authentication file path. The file should contain one or more user:password pairs in this format: "u1:p1,u2:p2,u3:p3" |
Out of memory errors? [Check the low VRAM guide](https://github.com/oobabooga/text-generation-webui/wiki/Low-VRAM-guide). Out of memory errors? [Check the low VRAM guide](https://github.com/oobabooga/text-generation-webui/wiki/Low-VRAM-guide).
@@ -289,9 +284,7 @@ Check the [wiki](https://github.com/oobabooga/text-generation-webui/wiki/System-
## Contributing ## Contributing
Pull requests, suggestions, and issue reports are welcome. Pull requests, suggestions, and issue reports are welcome.
You are also welcome to review open pull requests.
Before reporting a bug, make sure that you have: Before reporting a bug, make sure that you have:

View File

@@ -12,11 +12,6 @@ import string
import websockets import websockets
# Note, Gradio may pick a different fn value as the definition of the Gradio app changes.
# You can always launch the web UI and inspect the websocket stream using your browser's dev tools
# to determine what value Gradio expects here.
GRADIO_FN = 29
def random_hash(): def random_hash():
letters = string.ascii_lowercase + string.digits letters = string.ascii_lowercase + string.digits
@@ -41,10 +36,6 @@ async def run(context):
'length_penalty': 1, 'length_penalty': 1,
'early_stopping': False, 'early_stopping': False,
'seed': -1, 'seed': -1,
'add_bos_token': True,
'truncation_length': 2048,
'custom_stopping_strings': [],
'ban_eos_token': False
} }
payload = json.dumps([context, params]) payload = json.dumps([context, params])
session = random_hash() session = random_hash()
@@ -56,14 +47,14 @@ async def run(context):
case "send_hash": case "send_hash":
await websocket.send(json.dumps({ await websocket.send(json.dumps({
"session_hash": session, "session_hash": session,
"fn_index": GRADIO_FN "fn_index": 12
})) }))
case "estimation": case "estimation":
pass pass
case "send_data": case "send_data":
await websocket.send(json.dumps({ await websocket.send(json.dumps({
"session_hash": session, "session_hash": session,
"fn_index": GRADIO_FN, "fn_index": 12,
"data": [ "data": [
payload payload
] ]

View File

@@ -22,10 +22,10 @@ server = "127.0.0.1"
params = { params = {
'max_new_tokens': 200, 'max_new_tokens': 200,
'do_sample': True, 'do_sample': True,
'temperature': 0.72, 'temperature': 0.5,
'top_p': 0.73, 'top_p': 0.9,
'typical_p': 1, 'typical_p': 1,
'repetition_penalty': 1.1, 'repetition_penalty': 1.05,
'encoder_repetition_penalty': 1.0, 'encoder_repetition_penalty': 1.0,
'top_k': 0, 'top_k': 0,
'min_length': 0, 'min_length': 0,
@@ -35,10 +35,6 @@ params = {
'length_penalty': 1, 'length_penalty': 1,
'early_stopping': False, 'early_stopping': False,
'seed': -1, 'seed': -1,
'add_bos_token': True,
'custom_stopping_strings': [],
'truncation_length': 2048,
'ban_eos_token': False,
} }
# Input prompt # Input prompt

View File

@@ -36,8 +36,3 @@ div.svelte-362y77>*, div.svelte-362y77>.form>* {
.wrap.svelte-6roggh.svelte-6roggh { .wrap.svelte-6roggh.svelte-6roggh {
max-height: 92.5%; max-height: 92.5%;
} }
/* This is for the microphone button in the whisper extension */
.sm.svelte-1ipelgc {
width: 100%;
}

View File

@@ -7,13 +7,11 @@
padding-right: 20px; padding-right: 20px;
display: flex; display: flex;
flex-direction: column-reverse; flex-direction: column-reverse;
word-break: break-word;
overflow-wrap: anywhere;
} }
.message { .message {
display: grid; display: grid;
grid-template-columns: 60px minmax(0, 1fr); grid-template-columns: 60px 1fr;
padding-bottom: 25px; padding-bottom: 25px;
font-size: 15px; font-size: 15px;
font-family: Helvetica, Arial, sans-serif; font-family: Helvetica, Arial, sans-serif;
@@ -75,13 +73,6 @@
display: inline !important; display: inline !important;
} }
.message-body code {
overflow-x: auto;
}
.message-body :not(pre) > code {
white-space: normal !important;
}
.dark .message-body p em { .dark .message-body p em {
color: rgb(138, 138, 138) !important; color: rgb(138, 138, 138) !important;
} }

View File

@@ -7,8 +7,6 @@
padding-right: 20px; padding-right: 20px;
display: flex; display: flex;
flex-direction: column-reverse; flex-direction: column-reverse;
word-break: break-word;
overflow-wrap: anywhere;
} }
.message { .message {
@@ -27,7 +25,9 @@
.message-body {} .message-body {}
.message-body p { .message-body p {
margin-bottom: 0 !important;
font-size: 15px !important; font-size: 15px !important;
line-height: 1.428571429 !important;
} }
.message-body li { .message-body li {
@@ -39,13 +39,6 @@
display: inline !important; display: inline !important;
} }
.message-body code {
overflow-x: auto;
}
.message-body :not(pre) > code {
white-space: normal !important;
}
.dark .message-body p em { .dark .message-body p em {
color: rgb(138, 138, 138) !important; color: rgb(138, 138, 138) !important;
} }
@@ -58,16 +51,15 @@
padding: 15px; padding: 15px;
border-radius: 20px; border-radius: 20px;
background-color: #0000000f; background-color: #0000000f;
margin-top: 9px !important; margin-bottom: 17.5px;
margin-bottom: 18px !important;
} }
.gradio-container .chat .user-message { .gradio-container .chat .user-message {
padding: 15px; padding: 15px;
border-radius: 20px; border-radius: 20px;
margin-bottom: 9px !important; margin-bottom: 17.5px !important;
} }
.dark .chat .assistant-message { .dark .chat .assistant-message {
background-color: #374151; background-color: #ffffff21;
} }

View File

@@ -67,13 +67,3 @@ span.math.inline {
div.svelte-15lo0d8 > *, div.svelte-15lo0d8 > .form > * { div.svelte-15lo0d8 > *, div.svelte-15lo0d8 > .form > * {
flex-wrap: nowrap; flex-wrap: nowrap;
} }
.header_bar {
background-color: #f7f7f7;
margin-bottom: 40px;
}
.dark .header_bar {
border: none !important;
background-color: #8080802b;
}

View File

@@ -1,4 +1,4 @@
document.getElementById("main").parentNode.childNodes[0].classList.add("header_bar"); document.getElementById("main").parentNode.childNodes[0].style = "border: none; background-color: #8080802b; margin-bottom: 40px";
document.getElementById("main").parentNode.style = "padding: 0; margin: 0"; document.getElementById("main").parentNode.style = "padding: 0; margin: 0";
document.getElementById("main").parentNode.parentNode.parentNode.style = "padding: 0"; document.getElementById("main").parentNode.parentNode.parentNode.style = "padding: 0";

View File

@@ -6,6 +6,7 @@ services:
args: args:
# specify which cuda version your card supports: https://developer.nvidia.com/cuda-gpus # specify which cuda version your card supports: https://developer.nvidia.com/cuda-gpus
TORCH_CUDA_ARCH_LIST: ${TORCH_CUDA_ARCH_LIST} TORCH_CUDA_ARCH_LIST: ${TORCH_CUDA_ARCH_LIST}
GPTQ_VERSION: ${GPTQ_VERSION}
WEBUI_VERSION: ${WEBUI_VERSION} WEBUI_VERSION: ${WEBUI_VERSION}
env_file: .env env_file: .env
ports: ports:

View File

@@ -19,6 +19,50 @@ import requests
import tqdm import tqdm
from tqdm.contrib.concurrent import thread_map from tqdm.contrib.concurrent import thread_map
parser = argparse.ArgumentParser()
parser.add_argument('MODEL', type=str, default=None, nargs='?')
parser.add_argument('--branch', type=str, default='main', help='Name of the Git branch to download from.')
parser.add_argument('--threads', type=int, default=1, help='Number of files to download simultaneously.')
parser.add_argument('--text-only', action='store_true', help='Only download text files (txt/json).')
parser.add_argument('--output', type=str, default=None, help='The folder where the model should be saved.')
parser.add_argument('--clean', action='store_true', help='Does not resume the previous download.')
parser.add_argument('--check', action='store_true', help='Validates the checksums of model files.')
args = parser.parse_args()
def get_file(url, output_folder):
filename = Path(url.rsplit('/', 1)[1])
output_path = output_folder / filename
if output_path.exists() and not args.clean:
# Check if the file has already been downloaded completely
r = requests.get(url, stream=True)
total_size = int(r.headers.get('content-length', 0))
if output_path.stat().st_size >= total_size:
return
# Otherwise, resume the download from where it left off
headers = {'Range': f'bytes={output_path.stat().st_size}-'}
mode = 'ab'
else:
headers = {}
mode = 'wb'
r = requests.get(url, stream=True, headers=headers)
with open(output_path, mode) as f:
total_size = int(r.headers.get('content-length', 0))
block_size = 1024
with tqdm.tqdm(total=total_size, unit='iB', unit_scale=True, bar_format='{l_bar}{bar}| {n_fmt:6}/{total_fmt:6} {rate_fmt:6}') as t:
for data in r.iter_content(block_size):
t.update(len(data))
f.write(data)
def sanitize_branch_name(branch_name):
pattern = re.compile(r"^[a-zA-Z0-9._-]+$")
if pattern.match(branch_name):
return branch_name
else:
raise ValueError("Invalid branch name. Only alphanumeric characters, period, underscore and dash are allowed.")
def select_model_from_default_options(): def select_model_from_default_options():
models = { models = {
@@ -66,20 +110,7 @@ EleutherAI/pythia-1.4b-deduped
return model, branch return model, branch
def sanitize_model_and_branch_names(model, branch): def get_download_links_from_huggingface(model, branch):
if model[-1] == '/':
model = model[:-1]
if branch is None:
branch = "main"
else:
pattern = re.compile(r"^[a-zA-Z0-9._-]+$")
if not pattern.match(branch):
raise ValueError("Invalid branch name. Only alphanumeric characters, period, underscore and dash are allowed.")
return model, branch
def get_download_links_from_huggingface(model, branch, text_only=False):
base = "https://huggingface.co" base = "https://huggingface.co"
page = f"/api/models/{model}/tree/{branch}?cursor=" page = f"/api/models/{model}/tree/{branch}?cursor="
cursor = b"" cursor = b""
@@ -111,14 +142,14 @@ def get_download_links_from_huggingface(model, branch, text_only=False):
is_tokenizer = re.match("tokenizer.*\.model", fname) is_tokenizer = re.match("tokenizer.*\.model", fname)
is_text = re.match(".*\.(txt|json|py|md)", fname) or is_tokenizer is_text = re.match(".*\.(txt|json|py|md)", fname) or is_tokenizer
if any((is_pytorch, is_safetensors, is_pt, is_ggml, is_tokenizer, is_text)): if any((is_pytorch, is_safetensors, is_pt, is_tokenizer, is_text)):
if 'lfs' in dict[i]: if 'lfs' in dict[i]:
sha256.append([fname, dict[i]['lfs']['oid']]) sha256.append([fname, dict[i]['lfs']['oid']])
if is_text: if is_text:
links.append(f"https://huggingface.co/{model}/resolve/{branch}/{fname}") links.append(f"https://huggingface.co/{model}/resolve/{branch}/{fname}")
classifications.append('text') classifications.append('text')
continue continue
if not text_only: if not args.text_only:
links.append(f"https://huggingface.co/{model}/resolve/{branch}/{fname}") links.append(f"https://huggingface.co/{model}/resolve/{branch}/{fname}")
if is_safetensors: if is_safetensors:
has_safetensors = True has_safetensors = True
@@ -146,125 +177,80 @@ def get_download_links_from_huggingface(model, branch, text_only=False):
return links, sha256, is_lora return links, sha256, is_lora
def get_output_folder(model, branch, is_lora, base_folder=None): def download_files(file_list, output_folder, num_threads=8):
if base_folder is None: thread_map(lambda url: get_file(url, output_folder), file_list, max_workers=num_threads, disable=True)
if __name__ == '__main__':
model = args.MODEL
branch = args.branch
if model is None:
model, branch = select_model_from_default_options()
else:
if model[-1] == '/':
model = model[:-1]
branch = args.branch
if branch is None:
branch = "main"
else:
try:
branch = sanitize_branch_name(branch)
except ValueError as err_branch:
print(f"Error: {err_branch}")
sys.exit()
links, sha256, is_lora = get_download_links_from_huggingface(model, branch)
if args.output is not None:
base_folder = args.output
else:
base_folder = 'models' if not is_lora else 'loras' base_folder = 'models' if not is_lora else 'loras'
output_folder = f"{'_'.join(model.split('/')[-2:])}" output_folder = f"{'_'.join(model.split('/')[-2:])}"
if branch != 'main': if branch != 'main':
output_folder += f'_{branch}' output_folder += f'_{branch}'
output_folder = Path(base_folder) / output_folder output_folder = Path(base_folder) / output_folder
return output_folder
def get_single_file(url, output_folder, start_from_scratch=False):
filename = Path(url.rsplit('/', 1)[1])
output_path = output_folder / filename
if output_path.exists() and not start_from_scratch:
# Check if the file has already been downloaded completely
r = requests.get(url, stream=True)
total_size = int(r.headers.get('content-length', 0))
if output_path.stat().st_size >= total_size:
return
# Otherwise, resume the download from where it left off
headers = {'Range': f'bytes={output_path.stat().st_size}-'}
mode = 'ab'
else:
headers = {}
mode = 'wb'
r = requests.get(url, stream=True, headers=headers)
with open(output_path, mode) as f:
total_size = int(r.headers.get('content-length', 0))
block_size = 1024
with tqdm.tqdm(total=total_size, unit='iB', unit_scale=True, bar_format='{l_bar}{bar}| {n_fmt:6}/{total_fmt:6} {rate_fmt:6}') as t:
for data in r.iter_content(block_size):
t.update(len(data))
f.write(data)
def start_download_threads(file_list, output_folder, start_from_scratch=False, threads=1):
thread_map(lambda url: get_single_file(url, output_folder, start_from_scratch=start_from_scratch), file_list, max_workers=threads, disable=True)
def download_model_files(model, branch, links, sha256, output_folder, start_from_scratch=False, threads=1):
# Creating the folder and writing the metadata
if not output_folder.exists():
output_folder.mkdir()
with open(output_folder / 'huggingface-metadata.txt', 'w') as f:
f.write(f'url: https://huggingface.co/{model}\n')
f.write(f'branch: {branch}\n')
f.write(f'download date: {str(datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"))}\n')
sha256_str = ''
for i in range(len(sha256)):
sha256_str += f' {sha256[i][1]} {sha256[i][0]}\n'
if sha256_str != '':
f.write(f'sha256sum:\n{sha256_str}')
# Downloading the files
print(f"Downloading the model to {output_folder}")
start_download_threads(links, output_folder, start_from_scratch=start_from_scratch, threads=threads)
def check_model_files(model, branch, links, sha256, output_folder):
# Validate the checksums
validated = True
for i in range(len(sha256)):
fpath = (output_folder / sha256[i][0])
if not fpath.exists():
print(f"The following file is missing: {fpath}")
validated = False
continue
with open(output_folder / sha256[i][0], "rb") as f:
bytes = f.read()
file_hash = hashlib.sha256(bytes).hexdigest()
if file_hash != sha256[i][1]:
print(f'Checksum failed: {sha256[i][0]} {sha256[i][1]}')
validated = False
else:
print(f'Checksum validated: {sha256[i][0]} {sha256[i][1]}')
if validated:
print('[+] Validated checksums of all model files!')
else:
print('[-] Invalid checksums. Rerun download-model.py with the --clean flag.')
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('MODEL', type=str, default=None, nargs='?')
parser.add_argument('--branch', type=str, default='main', help='Name of the Git branch to download from.')
parser.add_argument('--threads', type=int, default=1, help='Number of files to download simultaneously.')
parser.add_argument('--text-only', action='store_true', help='Only download text files (txt/json).')
parser.add_argument('--output', type=str, default=None, help='The folder where the model should be saved.')
parser.add_argument('--clean', action='store_true', help='Does not resume the previous download.')
parser.add_argument('--check', action='store_true', help='Validates the checksums of model files.')
args = parser.parse_args()
branch = args.branch
model = args.MODEL
if model is None:
model, branch = select_model_from_default_options()
# Cleaning up the model/branch names
try:
model, branch = sanitize_model_and_branch_names(model, branch)
except ValueError as err_branch:
print(f"Error: {err_branch}")
sys.exit()
# Getting the download links from Hugging Face
links, sha256, is_lora = get_download_links_from_huggingface(model, branch, text_only=args.text_only)
# Getting the output folder
output_folder = get_output_folder(model, branch, is_lora, base_folder=args.output)
if args.check: if args.check:
# Check previously downloaded files # Validate the checksums
check_model_files(model, branch, links, sha256, output_folder) validated = True
for i in range(len(sha256)):
fpath = (output_folder / sha256[i][0])
if not fpath.exists():
print(f"The following file is missing: {fpath}")
validated = False
continue
with open(output_folder / sha256[i][0], "rb") as f:
bytes = f.read()
file_hash = hashlib.sha256(bytes).hexdigest()
if file_hash != sha256[i][1]:
print(f'Checksum failed: {sha256[i][0]} {sha256[i][1]}')
validated = False
else:
print(f'Checksum validated: {sha256[i][0]} {sha256[i][1]}')
if validated:
print('[+] Validated checksums of all model files!')
else:
print('[-] Invalid checksums. Rerun download-model.py with the --clean flag.')
else: else:
# Download files
download_model_files(model, branch, links, sha256, output_folder, threads=args.threads) # Creating the folder and writing the metadata
if not output_folder.exists():
output_folder.mkdir()
with open(output_folder / 'huggingface-metadata.txt', 'w') as f:
f.write(f'url: https://huggingface.co/{model}\n')
f.write(f'branch: {branch}\n')
f.write(f'download date: {str(datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"))}\n')
sha256_str = ''
for i in range(len(sha256)):
sha256_str += f' {sha256[i][1]} {sha256[i][0]}\n'
if sha256_str != '':
f.write(f'sha256sum:\n{sha256_str}')
# Downloading the files
print(f"Downloading the model to {output_folder}")
download_files(links, output_folder, args.threads)

View File

@@ -57,15 +57,12 @@ class Handler(BaseHTTPRequestHandler):
'length_penalty': float(body.get('length_penalty', 1)), 'length_penalty': float(body.get('length_penalty', 1)),
'early_stopping': bool(body.get('early_stopping', False)), 'early_stopping': bool(body.get('early_stopping', False)),
'seed': int(body.get('seed', -1)), 'seed': int(body.get('seed', -1)),
'add_bos_token': int(body.get('add_bos_token', True)),
'custom_stopping_strings': body.get('custom_stopping_strings', []),
'truncation_length': int(body.get('truncation_length', 2048)),
'ban_eos_token': bool(body.get('ban_eos_token', False)),
} }
generator = generate_reply( generator = generate_reply(
prompt, prompt,
generate_params, generate_params,
stopping_strings=body.get('stopping_strings', []),
) )
answer = '' answer = ''
@@ -81,19 +78,6 @@ class Handler(BaseHTTPRequestHandler):
}] }]
}) })
self.wfile.write(response.encode('utf-8')) self.wfile.write(response.encode('utf-8'))
elif self.path == '/api/v1/token-count':
# Not compatible with KoboldAI api
self.send_response(200)
self.send_header('Content-Type', 'application/json')
self.end_headers()
tokens = encode(body['prompt'])[0]
response = json.dumps({
'results': [{
'tokens': len(tokens)
}]
})
self.wfile.write(response.encode('utf-8'))
else: else:
self.send_error(404) self.send_error(404)

View File

@@ -1,23 +1,8 @@
import gradio as gr import gradio as gr
import os
# get the current directory of the script
current_dir = os.path.dirname(os.path.abspath(__file__))
# check if the bias_options.txt file exists, if not, create it
bias_file = os.path.join(current_dir, "bias_options.txt")
if not os.path.isfile(bias_file):
with open(bias_file, "w") as f:
f.write("*I am so happy*\n*I am so sad*\n*I am so excited*\n*I am so bored*\n*I am so angry*")
# read bias options from the text file
with open(bias_file, "r") as f:
bias_options = [line.strip() for line in f.readlines()]
params = { params = {
"activate": True, "activate": True,
"bias string": " *I am so happy*", "bias string": " *I am so happy*",
"use custom string": False,
} }
@@ -26,6 +11,7 @@ def input_modifier(string):
This function is applied to your text inputs before This function is applied to your text inputs before
they are fed into the model. they are fed into the model.
""" """
return string return string
@@ -33,6 +19,7 @@ def output_modifier(string):
""" """
This function is applied to the model outputs. This function is applied to the model outputs.
""" """
return string return string
@@ -42,11 +29,9 @@ def bot_prefix_modifier(string):
the prefix text for the Bot and can be used to bias its the prefix text for the Bot and can be used to bias its
behavior. behavior.
""" """
if params['activate']: if params['activate']:
if params['use custom string']: return f'{string} {params["bias string"].strip()} '
return f'{string} {params["custom string"].strip()} '
else:
return f'{string} {params["bias string"].strip()} '
else: else:
return string return string
@@ -54,29 +39,8 @@ def bot_prefix_modifier(string):
def ui(): def ui():
# Gradio elements # Gradio elements
activate = gr.Checkbox(value=params['activate'], label='Activate character bias') activate = gr.Checkbox(value=params['activate'], label='Activate character bias')
dropdown_string = gr.Dropdown(choices=bias_options, value=params["bias string"], label='Character bias', info='To edit the options in this dropdown edit the "bias_options.txt" file') string = gr.Textbox(value=params["bias string"], label='Character bias')
use_custom_string = gr.Checkbox(value=False, label='Use custom bias textbox instead of dropdown')
custom_string = gr.Textbox(value="", placeholder="Enter custom bias string", label="Custom Character Bias", info='To use this textbox activate the checkbox above')
# Event functions to update the parameters in the backend # Event functions to update the parameters in the backend
def update_bias_string(x): string.change(lambda x: params.update({"bias string": x}), string, None)
if x:
params.update({"bias string": x})
else:
params.update({"bias string": dropdown_string.get()})
return x
def update_custom_string(x):
params.update({"custom string": x})
dropdown_string.change(update_bias_string, dropdown_string, None)
custom_string.change(update_custom_string, custom_string, None)
activate.change(lambda x: params.update({"activate": x}), activate, None) activate.change(lambda x: params.update({"activate": x}), activate, None)
use_custom_string.change(lambda x: params.update({"use custom string": x}), use_custom_string, None)
# Group elements together depending on the selected option
def bias_string_group():
if use_custom_string.value:
return gr.Group([use_custom_string, custom_string])
else:
return dropdown_string

View File

@@ -2,11 +2,10 @@ import re
from pathlib import Path from pathlib import Path
import gradio as gr import gradio as gr
import modules.shared as shared
from elevenlabslib import ElevenLabsUser from elevenlabslib import ElevenLabsUser
from elevenlabslib.helpers import save_bytes_to_path from elevenlabslib.helpers import save_bytes_to_path
import modules.shared as shared
params = { params = {
'activate': True, 'activate': True,
'api_key': '12345', 'api_key': '12345',

View File

@@ -87,10 +87,10 @@ def ui():
update = gr.Button("Refresh") update = gr.Button("Refresh")
gr.HTML(value="<style>" + generate_css() + "</style>") gr.HTML(value="<style>" + generate_css() + "</style>")
gallery = gr.Dataset(components=[gr.HTML(visible=False)], gallery = gr.Dataset(components=[gr.HTML(visible=False)],
label="", label="",
samples=generate_html(), samples=generate_html(),
elem_classes=["character-gallery"], elem_classes=["character-gallery"],
samples_per_page=50 samples_per_page=50
) )
update.click(generate_html, [], gallery) update.click(generate_html, [], gallery)
gallery.select(select_character, None, gradio['character_menu']) gallery.select(select_character, None, gradio['character_menu'])

View File

@@ -1,7 +1,6 @@
import gradio as gr import gradio as gr
import pandas as pd
import modules.shared as shared import modules.shared as shared
import pandas as pd
df = pd.read_csv("https://raw.githubusercontent.com/devbrones/llama-prompts/main/prompts/prompts.csv") df = pd.read_csv("https://raw.githubusercontent.com/devbrones/llama-prompts/main/prompts/prompts.csv")

View File

@@ -1,78 +0,0 @@
## Description:
TL;DR: Lets the bot answer you with a picture!
Stable Diffusion API pictures for TextGen, v.1.1.0
An extension to [oobabooga's textgen-webui](https://github.com/oobabooga/text-generation-webui) allowing you to receive pics generated by [Automatic1111's SD-WebUI API](https://github.com/AUTOMATIC1111/stable-diffusion-webui)
<details>
<summary>Interface overview</summary>
![Interface](https://raw.githubusercontent.com/Brawlence/texgen-webui-SD_api_pics/main/illust/Interface.jpg)
</details>
Load it in the `--chat` mode with `--extension sd_api_pictures` alongside `send_pictures` (it's not really required, but completes the picture, *pun intended*).
The image generation is triggered either:
- manually through the 'Force the picture response' button while in `Manual` or `Immersive/Interactive` modes OR
- automatically in `Immersive/Interactive` mode if the words `'send|main|message|me'` are followed by `'image|pic|picture|photo|snap|snapshot|selfie|meme'` in the user's prompt
- always on in Picturebook/Adventure mode (if not currently suppressed by 'Suppress the picture response')
## Prerequisites
One needs an available instance of Automatic1111's webui running with an `--api` flag. Ain't tested with a notebook / cloud hosted one but should be possible.
To run it locally in parallel on the same machine, specify custom `--listen-port` for either Auto1111's or ooba's webUIs.
## Features:
- API detection (press enter in the API box)
- VRAM management (model shuffling)
- Three different operation modes (manual, interactive, always-on)
- persistent settings via settings.json
The model input is modified only in the interactive mode; other two are unaffected. The output pic description is presented differently for Picture-book / Adventure mode.
Connection check (insert the Auto1111's address and press Enter):
![API-check](https://raw.githubusercontent.com/Brawlence/texgen-webui-SD_api_pics/main/illust/API-check.gif)
### Persistents settings
Create or modify the `settings.json` in the `text-generation-webui` root directory to override the defaults
present in script.py, ex:
```json
{
"sd_api_pictures-manage_VRAM": 1,
"sd_api_pictures-save_img": 1,
"sd_api_pictures-prompt_prefix": "(Masterpiece:1.1), detailed, intricate, colorful, (solo:1.1)",
"sd_api_pictures-sampler_name": "DPM++ 2M Karras"
}
```
will automatically set the `Manage VRAM` & `Keep original images` checkboxes and change the texts in `Prompt Prefix` and `Sampler name` on load.
---
## Demonstrations:
Those are examples of the version 1.0.0, but the core functionality is still the same
<details>
<summary>Conversation 1</summary>
![EXA1](https://user-images.githubusercontent.com/42910943/224866564-939a3bcb-e7cf-4ac0-a33f-b3047b55054d.jpg)
![EXA2](https://user-images.githubusercontent.com/42910943/224866566-38394054-1320-45cf-9515-afa76d9d7745.jpg)
![EXA3](https://user-images.githubusercontent.com/42910943/224866568-10ea47b7-0bac-4269-9ec9-22c387a13b59.jpg)
![EXA4](https://user-images.githubusercontent.com/42910943/224866569-326121ad-1ea1-4874-9f6b-4bca7930a263.jpg)
</details>
<details>
<summary>Conversation 2</summary>
![Hist1](https://user-images.githubusercontent.com/42910943/224865517-c6966b58-bc4d-4353-aab9-6eb97778d7bf.jpg)
![Hist2](https://user-images.githubusercontent.com/42910943/224865527-b2fe7c2e-0da5-4c2e-b705-42e233b07084.jpg)
![Hist3](https://user-images.githubusercontent.com/42910943/224865535-a38d94e7-8975-4a46-a655-1ae1de41f85d.jpg)
</details>

View File

@@ -1,78 +1,34 @@
import base64 import base64
import io import io
import re import re
import time
from datetime import date
from pathlib import Path from pathlib import Path
import gradio as gr import gradio as gr
import modules.chat as chat
import modules.shared as shared import modules.shared as shared
import requests import requests
import torch import torch
from modules.models import reload_model, unload_model
from PIL import Image from PIL import Image
torch._C._jit_set_profiling_mode(False) torch._C._jit_set_profiling_mode(False)
# parameters which can be customized in settings.json of webui # parameters which can be customized in settings.json of webui
params = { params = {
'enable_SD_api': False,
'address': 'http://127.0.0.1:7860', 'address': 'http://127.0.0.1:7860',
'mode': 0, # modes of operation: 0 (Manual only), 1 (Immersive/Interactive - looks for words to trigger), 2 (Picturebook Adventure - Always on)
'manage_VRAM': False,
'save_img': False, 'save_img': False,
'SD_model': 'NeverEndingDream', # not used right now 'SD_model': 'NeverEndingDream', # not really used right now
'prompt_prefix': '(Masterpiece:1.1), detailed, intricate, colorful', 'prompt_prefix': '(Masterpiece:1.1), (solo:1.3), detailed, intricate, colorful',
'negative_prompt': '(worst quality, low quality:1.3)', 'negative_prompt': '(worst quality, low quality:1.3)',
'width': 512, 'side_length': 512,
'height': 512, 'restore_faces': False
'restore_faces': False,
'seed': -1,
'sampler_name': 'DDIM',
'steps': 32,
'cfg_scale': 7
} }
def give_VRAM_priority(actor):
global shared, params
if actor == 'SD':
unload_model()
print("Requesting Auto1111 to re-load last checkpoint used...")
response = requests.post(url=f'{params["address"]}/sdapi/v1/reload-checkpoint', json='')
response.raise_for_status()
elif actor == 'LLM':
print("Requesting Auto1111 to vacate VRAM...")
response = requests.post(url=f'{params["address"]}/sdapi/v1/unload-checkpoint', json='')
response.raise_for_status()
reload_model()
elif actor == 'set':
print("VRAM mangement activated -- requesting Auto1111 to vacate VRAM...")
response = requests.post(url=f'{params["address"]}/sdapi/v1/unload-checkpoint', json='')
response.raise_for_status()
elif actor == 'reset':
print("VRAM mangement deactivated -- requesting Auto1111 to reload checkpoint")
response = requests.post(url=f'{params["address"]}/sdapi/v1/reload-checkpoint', json='')
response.raise_for_status()
else:
raise RuntimeError(f'Managing VRAM: "{actor}" is not a known state!')
response.raise_for_status()
del response
if params['manage_VRAM']:
give_VRAM_priority('set')
samplers = ['DDIM', 'DPM++ 2M Karras'] # TODO: get the availible samplers with http://{address}}/sdapi/v1/samplers
SD_models = ['NeverEndingDream'] # TODO: get with http://{address}}/sdapi/v1/sd-models and allow user to select SD_models = ['NeverEndingDream'] # TODO: get with http://{address}}/sdapi/v1/sd-models and allow user to select
streaming_state = shared.args.no_stream # remember if chat streaming was enabled streaming_state = shared.args.no_stream # remember if chat streaming was enabled
picture_response = False # specifies if the next model response should appear as a picture picture_response = False # specifies if the next model response should appear as a picture
pic_id = 0
def remove_surrounded_chars(string): def remove_surrounded_chars(string):
@@ -80,13 +36,7 @@ def remove_surrounded_chars(string):
# 'as few symbols as possible (0 upwards) between an asterisk and the end of the string' # 'as few symbols as possible (0 upwards) between an asterisk and the end of the string'
return re.sub('\*[^\*]*?(\*|$)', '', string) return re.sub('\*[^\*]*?(\*|$)', '', string)
# I don't even need input_hijack for this as visible text will be commited to history as the unmodified string
def triggers_are_in(string):
string = remove_surrounded_chars(string)
# regex searches for send|main|message|me (at the end of the word) followed by
# a whole word of image|pic|picture|photo|snap|snapshot|selfie|meme(s),
# (?aims) are regex parser flags
return bool(re.search('(?aims)(send|mail|message|me)\\b.+?\\b(image|pic(ture)?|photo|snap(shot)?|selfie|meme)s?\\b', string))
def input_modifier(string): def input_modifier(string):
@@ -94,80 +44,75 @@ def input_modifier(string):
This function is applied to your text inputs before This function is applied to your text inputs before
they are fed into the model. they are fed into the model.
""" """
global params, picture_response
global params if not params['enable_SD_api']:
if not params['mode'] == 1: # if not in immersive/interactive mode, do nothing
return string return string
if triggers_are_in(string): # if we're in it, check for trigger words commands = ['send', 'mail', 'me']
toggle_generation(True) mediums = ['image', 'pic', 'picture', 'photo']
string = string.lower() subjects = ['yourself', 'own']
if "of" in string: lowstr = string.lower()
subject = string.split('of', 1)[1] # subdivide the string once by the first 'of' instance and get what's coming after it
string = "Please provide a detailed and vivid description of " + subject # TODO: refactor out to separate handler and also replace detection with a regexp
else: if any(command in lowstr for command in commands) and any(case in lowstr for case in mediums): # trigger the generation if a command signature and a medium signature is found
string = "Please provide a detailed description of your appearance, your surroundings and what you are doing right now" picture_response = True
shared.args.no_stream = True # Disable streaming cause otherwise the SD-generated picture would return as a dud
shared.processing_message = "*Is sending a picture...*"
string = "Please provide a detailed description of your surroundings, how you look and the situation you're in and what you are doing right now"
if any(target in lowstr for target in subjects): # the focus of the image should be on the sending character
string = "Please provide a detailed and vivid description of how you look and what you are wearing"
return string return string
# Get and save the Stable Diffusion-generated picture # Get and save the Stable Diffusion-generated picture
def get_SD_pictures(description): def get_SD_pictures(description):
global params global params, pic_id
if params['manage_VRAM']:
give_VRAM_priority('SD')
payload = { payload = {
"prompt": params['prompt_prefix'] + description, "prompt": params['prompt_prefix'] + description,
"seed": params['seed'], "seed": -1,
"sampler_name": params['sampler_name'], "sampler_name": "DPM++ 2M Karras",
"steps": params['steps'], "steps": 32,
"cfg_scale": params['cfg_scale'], "cfg_scale": 7,
"width": params['width'], "width": params['side_length'],
"height": params['height'], "height": params['side_length'],
"restore_faces": params['restore_faces'], "restore_faces": params['restore_faces'],
"negative_prompt": params['negative_prompt'] "negative_prompt": params['negative_prompt']
} }
print(f'Prompting the image generator via the API on {params["address"]}...')
response = requests.post(url=f'{params["address"]}/sdapi/v1/txt2img', json=payload) response = requests.post(url=f'{params["address"]}/sdapi/v1/txt2img', json=payload)
response.raise_for_status()
r = response.json() r = response.json()
visible_result = "" visible_result = ""
for img_str in r['images']: for img_str in r['images']:
image = Image.open(io.BytesIO(base64.b64decode(img_str.split(",", 1)[0]))) image = Image.open(io.BytesIO(base64.b64decode(img_str.split(",", 1)[0])))
if params['save_img']: if params['save_img']:
variadic = f'{date.today().strftime("%Y_%m_%d")}/{shared.character}_{int(time.time())}' output_file = Path(f'extensions/sd_api_pictures/outputs/{pic_id:06d}.png')
output_file = Path(f'extensions/sd_api_pictures/outputs/{variadic}.png')
output_file.parent.mkdir(parents=True, exist_ok=True)
image.save(output_file.as_posix()) image.save(output_file.as_posix())
visible_result = visible_result + f'<img src="/file/extensions/sd_api_pictures/outputs/{variadic}.png" alt="{description}" style="max-width: unset; max-height: unset;">\n' pic_id += 1
else: # lower the resolution of received images for the chat, otherwise the log size gets out of control quickly with all the base64 values in visible history
# lower the resolution of received images for the chat, otherwise the log size gets out of control quickly with all the base64 values in visible history image.thumbnail((300, 300))
image.thumbnail((300, 300)) buffered = io.BytesIO()
buffered = io.BytesIO() image.save(buffered, format="JPEG")
image.save(buffered, format="JPEG") buffered.seek(0)
buffered.seek(0) image_bytes = buffered.getvalue()
image_bytes = buffered.getvalue() img_str = "data:image/jpeg;base64," + base64.b64encode(image_bytes).decode()
img_str = "data:image/jpeg;base64," + base64.b64encode(image_bytes).decode() visible_result = visible_result + f'<img src="{img_str}" alt="{description}">\n'
visible_result = visible_result + f'<img src="{img_str}" alt="{description}">\n'
if params['manage_VRAM']:
give_VRAM_priority('LLM')
return visible_result return visible_result
# TODO: how do I make the UI history ignore the resulting pictures (I don't want HTML to appear in history) # TODO: how do I make the UI history ignore the resulting pictures (I don't want HTML to appear in history)
# and replace it with 'text' for the purposes of logging? # and replace it with 'text' for the purposes of logging?
def output_modifier(string): def output_modifier(string):
""" """
This function is applied to the model outputs. This function is applied to the model outputs.
""" """
global pic_id, picture_response, streaming_state
global picture_response, params
if not picture_response: if not picture_response:
return string return string
@@ -180,18 +125,17 @@ def output_modifier(string):
if string == '': if string == '':
string = 'no viable description in reply, try regenerating' string = 'no viable description in reply, try regenerating'
return string
text = "" # I can't for the love of all that's holy get the name from shared.gradio['name1'], so for now it will be like this
if (params['mode'] < 2): text = f'*Description: "{string}"*'
toggle_generation(False)
text = f'*Sends a picture which portrays: “{string}”*'
else:
text = string
string = get_SD_pictures(string) + "\n" + text image = get_SD_pictures(string)
return string picture_response = False
shared.processing_message = "*Is typing...*"
shared.args.no_stream = streaming_state
return image + "\n" + text
def bot_prefix_modifier(string): def bot_prefix_modifier(string):
@@ -204,91 +148,42 @@ def bot_prefix_modifier(string):
return string return string
def toggle_generation(*args): def force_pic():
global picture_response, shared, streaming_state global picture_response
picture_response = True
if not args:
picture_response = not picture_response
else:
picture_response = args[0]
shared.args.no_stream = True if picture_response else streaming_state # Disable streaming cause otherwise the SD-generated picture would return as a dud
shared.processing_message = "*Is sending a picture...*" if picture_response else "*Is typing...*"
def filter_address(address):
address = address.strip()
# address = re.sub('http(s)?:\/\/|\/$','',address) # remove starting http:// OR https:// OR trailing slash
address = re.sub('\/$', '', address) # remove trailing /s
if not address.startswith('http'):
address = 'http://' + address
return address
def SD_api_address_update(address):
global params
msg = "✔️ SD API is found on:"
address = filter_address(address)
params.update({"address": address})
try:
response = requests.get(url=f'{params["address"]}/sdapi/v1/sd-models')
response.raise_for_status()
# r = response.json()
except:
msg = "❌ No SD API endpoint on:"
return gr.Textbox.update(label=msg)
def ui(): def ui():
# Gradio elements # Gradio elements
# gr.Markdown('### Stable Diffusion API Pictures') # Currently the name of extension is shown as the title with gr.Accordion("Stable Diffusion api integration", open=True):
with gr.Accordion("Parameters", open=True):
with gr.Row(): with gr.Row():
address = gr.Textbox(placeholder=params['address'], value=params['address'], label='Auto1111\'s WebUI address') with gr.Column():
mode = gr.Dropdown(["Manual", "Immersive/Interactive", "Picturebook/Adventure"], value="Manual", label="Mode of operation", type="index") enable = gr.Checkbox(value=params['enable_SD_api'], label='Activate SD Api integration')
with gr.Column(scale=1, min_width=300): save_img = gr.Checkbox(value=params['save_img'], label='Keep original received images in the outputs subdir')
manage_VRAM = gr.Checkbox(value=params['manage_VRAM'], label='Manage VRAM') with gr.Column():
save_img = gr.Checkbox(value=params['save_img'], label='Keep original images and use them in chat') address = gr.Textbox(placeholder=params['address'], value=params['address'], label='Stable Diffusion host address')
force_pic = gr.Button("Force the picture response") with gr.Row():
suppr_pic = gr.Button("Suppress the picture response") force_btn = gr.Button("Force the next response to be a picture")
generate_now_btn = gr.Button("Generate an image response to the input")
with gr.Accordion("Generation parameters", open=False): with gr.Accordion("Generation parameters", open=False):
prompt_prefix = gr.Textbox(placeholder=params['prompt_prefix'], value=params['prompt_prefix'], label='Prompt Prefix (best used to describe the look of the character)') prompt_prefix = gr.Textbox(placeholder=params['prompt_prefix'], value=params['prompt_prefix'], label='Prompt Prefix (best used to describe the look of the character)')
with gr.Row(): with gr.Row():
with gr.Column(): negative_prompt = gr.Textbox(placeholder=params['negative_prompt'], value=params['negative_prompt'], label='Negative Prompt')
negative_prompt = gr.Textbox(placeholder=params['negative_prompt'], value=params['negative_prompt'], label='Negative Prompt') dimensions = gr.Slider(256, 702, value=params['side_length'], step=64, label='Image dimensions')
sampler_name = gr.Textbox(placeholder=params['sampler_name'], value=params['sampler_name'], label='Sampler') # model = gr.Dropdown(value=SD_models[0], choices=SD_models, label='Model')
with gr.Column():
width = gr.Slider(256, 768, value=params['width'], step=64, label='Width')
height = gr.Slider(256, 768, value=params['height'], step=64, label='Height')
with gr.Row():
steps = gr.Number(label="Steps:", value=params['steps'])
seed = gr.Number(label="Seed:", value=params['seed'])
cfg_scale = gr.Number(label="CFG Scale:", value=params['cfg_scale'])
# Event functions to update the parameters in the backend # Event functions to update the parameters in the backend
address.change(lambda x: params.update({"address": filter_address(x)}), address, None) enable.change(lambda x: params.update({"enable_SD_api": x}), enable, None)
mode.select(lambda x: params.update({"mode": x}), mode, None)
mode.select(lambda x: toggle_generation(x > 1), inputs=mode, outputs=None)
manage_VRAM.change(lambda x: params.update({"manage_VRAM": x}), manage_VRAM, None)
manage_VRAM.change(lambda x: give_VRAM_priority('set' if x else 'reset'), inputs=manage_VRAM, outputs=None)
save_img.change(lambda x: params.update({"save_img": x}), save_img, None) save_img.change(lambda x: params.update({"save_img": x}), save_img, None)
address.change(lambda x: params.update({"address": x}), address, None)
address.submit(fn=SD_api_address_update, inputs=address, outputs=address)
prompt_prefix.change(lambda x: params.update({"prompt_prefix": x}), prompt_prefix, None) prompt_prefix.change(lambda x: params.update({"prompt_prefix": x}), prompt_prefix, None)
negative_prompt.change(lambda x: params.update({"negative_prompt": x}), negative_prompt, None) negative_prompt.change(lambda x: params.update({"negative_prompt": x}), negative_prompt, None)
width.change(lambda x: params.update({"width": x}), width, None) dimensions.change(lambda x: params.update({"side_length": x}), dimensions, None)
height.change(lambda x: params.update({"height": x}), height, None) # model.change(lambda x: params.update({"SD_model": x}), model, None)
sampler_name.change(lambda x: params.update({"sampler_name": x}), sampler_name, None) force_btn.click(force_pic)
steps.change(lambda x: params.update({"steps": x}), steps, None) generate_now_btn.click(force_pic)
seed.change(lambda x: params.update({"seed": x}), seed, None) generate_now_btn.click(chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream)
cfg_scale.change(lambda x: params.update({"cfg_scale": x}), cfg_scale, None)
force_pic.click(lambda x: toggle_generation(True), inputs=force_pic, outputs=None)
suppr_pic.click(lambda x: toggle_generation(False), inputs=suppr_pic, outputs=None)

View File

@@ -25,7 +25,7 @@ def caption_image(raw_image):
def generate_chat_picture(picture, name1, name2): def generate_chat_picture(picture, name1, name2):
text = f'*{name1} sends {name2} a picture that contains the following: {caption_image(picture)}*' text = f'*{name1} sends {name2} a picture that contains the following: "{caption_image(picture)}"*'
# lower the resolution of sent images for the chat, otherwise the log size gets out of control quickly with all the base64 values in visible history # lower the resolution of sent images for the chat, otherwise the log size gets out of control quickly with all the base64 values in visible history
picture.thumbnail((300, 300)) picture.thumbnail((300, 300))
buffer = BytesIO() buffer = BytesIO()

View File

@@ -1,5 +1,6 @@
ipython ipython
num2words
omegaconf omegaconf
pydub pydub
PyYAML PyYAML
torch
torchaudio

View File

@@ -1,16 +1,14 @@
import re
import time import time
from pathlib import Path from pathlib import Path
import gradio as gr import gradio as gr
import modules.chat as chat
import modules.shared as shared
import torch import torch
from extensions.silero_tts import tts_preprocessor
from modules import chat, shared
from modules.html_generator import chat_html_wrapper
torch._C._jit_set_profiling_mode(False) torch._C._jit_set_profiling_mode(False)
params = { params = {
'activate': True, 'activate': True,
'speaker': 'en_56', 'speaker': 'en_56',
@@ -22,14 +20,13 @@ params = {
'autoplay': True, 'autoplay': True,
'voice_pitch': 'medium', 'voice_pitch': 'medium',
'voice_speed': 'medium', 'voice_speed': 'medium',
'local_cache_path': '' # User can override the default cache path to something other via settings.json
} }
current_params = params.copy() current_params = params.copy()
voices_by_gender = ['en_99', 'en_45', 'en_18', 'en_117', 'en_49', 'en_51', 'en_68', 'en_0', 'en_26', 'en_56', 'en_74', 'en_5', 'en_38', 'en_53', 'en_21', 'en_37', 'en_107', 'en_10', 'en_82', 'en_16', 'en_41', 'en_12', 'en_67', 'en_61', 'en_14', 'en_11', 'en_39', 'en_52', 'en_24', 'en_97', 'en_28', 'en_72', 'en_94', 'en_36', 'en_4', 'en_43', 'en_88', 'en_25', 'en_65', 'en_6', 'en_44', 'en_75', 'en_91', 'en_60', 'en_109', 'en_85', 'en_101', 'en_108', 'en_50', 'en_96', 'en_64', 'en_92', 'en_76', 'en_33', 'en_116', 'en_48', 'en_98', 'en_86', 'en_62', 'en_54', 'en_95', 'en_55', 'en_111', 'en_3', 'en_83', 'en_8', 'en_47', 'en_59', 'en_1', 'en_2', 'en_7', 'en_9', 'en_13', 'en_15', 'en_17', 'en_19', 'en_20', 'en_22', 'en_23', 'en_27', 'en_29', 'en_30', 'en_31', 'en_32', 'en_34', 'en_35', 'en_40', 'en_42', 'en_46', 'en_57', 'en_58', 'en_63', 'en_66', 'en_69', 'en_70', 'en_71', 'en_73', 'en_77', 'en_78', 'en_79', 'en_80', 'en_81', 'en_84', 'en_87', 'en_89', 'en_90', 'en_93', 'en_100', 'en_102', 'en_103', 'en_104', 'en_105', 'en_106', 'en_110', 'en_112', 'en_113', 'en_114', 'en_115'] voices_by_gender = ['en_99', 'en_45', 'en_18', 'en_117', 'en_49', 'en_51', 'en_68', 'en_0', 'en_26', 'en_56', 'en_74', 'en_5', 'en_38', 'en_53', 'en_21', 'en_37', 'en_107', 'en_10', 'en_82', 'en_16', 'en_41', 'en_12', 'en_67', 'en_61', 'en_14', 'en_11', 'en_39', 'en_52', 'en_24', 'en_97', 'en_28', 'en_72', 'en_94', 'en_36', 'en_4', 'en_43', 'en_88', 'en_25', 'en_65', 'en_6', 'en_44', 'en_75', 'en_91', 'en_60', 'en_109', 'en_85', 'en_101', 'en_108', 'en_50', 'en_96', 'en_64', 'en_92', 'en_76', 'en_33', 'en_116', 'en_48', 'en_98', 'en_86', 'en_62', 'en_54', 'en_95', 'en_55', 'en_111', 'en_3', 'en_83', 'en_8', 'en_47', 'en_59', 'en_1', 'en_2', 'en_7', 'en_9', 'en_13', 'en_15', 'en_17', 'en_19', 'en_20', 'en_22', 'en_23', 'en_27', 'en_29', 'en_30', 'en_31', 'en_32', 'en_34', 'en_35', 'en_40', 'en_42', 'en_46', 'en_57', 'en_58', 'en_63', 'en_66', 'en_69', 'en_70', 'en_71', 'en_73', 'en_77', 'en_78', 'en_79', 'en_80', 'en_81', 'en_84', 'en_87', 'en_89', 'en_90', 'en_93', 'en_100', 'en_102', 'en_103', 'en_104', 'en_105', 'en_106', 'en_110', 'en_112', 'en_113', 'en_114', 'en_115']
voice_pitches = ['x-low', 'low', 'medium', 'high', 'x-high'] voice_pitches = ['x-low', 'low', 'medium', 'high', 'x-high']
voice_speeds = ['x-slow', 'slow', 'medium', 'fast', 'x-fast'] voice_speeds = ['x-slow', 'slow', 'medium', 'fast', 'x-fast']
streaming_state = shared.args.no_stream # remember if chat streaming was enabled streaming_state = shared.args.no_stream # remember if chat streaming was enabled
# Used for making text xml compatible, needed for voice pitch and speed control # Used for making text xml compatible, needed for voice pitch and speed control
table = str.maketrans({ table = str.maketrans({
@@ -40,31 +37,26 @@ table = str.maketrans({
'"': "&quot;", '"': "&quot;",
}) })
def xmlesc(txt): def xmlesc(txt):
return txt.translate(table) return txt.translate(table)
def load_model(): def load_model():
torch_cache_path = torch.hub.get_dir() if params['local_cache_path'] == '' else params['local_cache_path'] model, example_text = torch.hub.load(repo_or_dir='snakers4/silero-models', model='silero_tts', language=params['language'], speaker=params['model_id'])
model_path = torch_cache_path + "/snakers4_silero-models_master/src/silero/model/" + params['model_id'] + ".pt"
if Path(model_path).is_file():
print(f'\nUsing Silero TTS cached checkpoint found at {torch_cache_path}')
model, example_text = torch.hub.load(repo_or_dir=torch_cache_path + '/snakers4_silero-models_master/', model='silero_tts', language=params['language'], speaker=params['model_id'], source='local', path=model_path, force_reload=True)
else:
print(f'\nSilero TTS cache not found at {torch_cache_path}. Attempting to download...')
model, example_text = torch.hub.load(repo_or_dir='snakers4/silero-models', model='silero_tts', language=params['language'], speaker=params['model_id'])
model.to(params['device']) model.to(params['device'])
return model return model
model = load_model()
def remove_surrounded_chars(string):
# this expression matches to 'as few symbols as possible (0 upwards) between any asterisks' OR
# 'as few symbols as possible (0 upwards) between an asterisk and the end of the string'
return re.sub('\*[^\*]*?(\*|$)','',string)
def remove_tts_from_history(name1, name2, mode): def remove_tts_from_history(name1, name2):
for i, entry in enumerate(shared.history['internal']): for i, entry in enumerate(shared.history['internal']):
shared.history['visible'][i] = [shared.history['visible'][i][0], entry[1]] shared.history['visible'][i] = [shared.history['visible'][i][0], entry[1]]
return chat_html_wrapper(shared.history['visible'], name1, name2, mode) return chat.generate_chat_output(shared.history['visible'], name1, name2, shared.character)
def toggle_text_in_history(name1, name2):
def toggle_text_in_history(name1, name2, mode):
for i, entry in enumerate(shared.history['visible']): for i, entry in enumerate(shared.history['visible']):
visible_reply = entry[1] visible_reply = entry[1]
if visible_reply.startswith('<audio'): if visible_reply.startswith('<audio'):
@@ -73,8 +65,7 @@ def toggle_text_in_history(name1, name2, mode):
shared.history['visible'][i] = [shared.history['visible'][i][0], f"{visible_reply.split('</audio>')[0]}</audio>\n\n{reply}"] shared.history['visible'][i] = [shared.history['visible'][i][0], f"{visible_reply.split('</audio>')[0]}</audio>\n\n{reply}"]
else: else:
shared.history['visible'][i] = [shared.history['visible'][i][0], f"{visible_reply.split('</audio>')[0]}</audio>"] shared.history['visible'][i] = [shared.history['visible'][i][0], f"{visible_reply.split('</audio>')[0]}</audio>"]
return chat_html_wrapper(shared.history['visible'], name1, name2, mode) return chat.generate_chat_output(shared.history['visible'], name1, name2, shared.character)
def input_modifier(string): def input_modifier(string):
""" """
@@ -84,13 +75,12 @@ def input_modifier(string):
# Remove autoplay from the last reply # Remove autoplay from the last reply
if shared.is_chat() and len(shared.history['internal']) > 0: if shared.is_chat() and len(shared.history['internal']) > 0:
shared.history['visible'][-1] = [shared.history['visible'][-1][0], shared.history['visible'][-1][1].replace('controls autoplay>', 'controls>')] shared.history['visible'][-1] = [shared.history['visible'][-1][0], shared.history['visible'][-1][1].replace('controls autoplay>','controls>')]
shared.processing_message = "*Is recording a voice message...*" shared.processing_message = "*Is recording a voice message...*"
shared.args.no_stream = True # Disable streaming cause otherwise the audio output will stutter and begin anew every time the message is being updated shared.args.no_stream = True # Disable streaming cause otherwise the audio output will stutter and begin anew every time the message is being updated
return string return string
def output_modifier(string): def output_modifier(string):
""" """
This function is applied to the model outputs. This function is applied to the model outputs.
@@ -104,11 +94,15 @@ def output_modifier(string):
current_params = params.copy() current_params = params.copy()
break break
if not params['activate']: if params['activate'] == False:
return string return string
original_string = string original_string = string
string = tts_preprocessor.preprocess(string) string = remove_surrounded_chars(string)
string = string.replace('"', '')
string = string.replace('', '')
string = string.replace('\n', ' ')
string = string.strip()
if string == '': if string == '':
string = '*Empty reply, try regenerating*' string = '*Empty reply, try regenerating*'
@@ -124,10 +118,9 @@ def output_modifier(string):
string += f'\n\n{original_string}' string += f'\n\n{original_string}'
shared.processing_message = "*Is typing...*" shared.processing_message = "*Is typing...*"
shared.args.no_stream = streaming_state # restore the streaming option to the previous value shared.args.no_stream = streaming_state # restore the streaming option to the previous value
return string return string
def bot_prefix_modifier(string): def bot_prefix_modifier(string):
""" """
This function is only applied in chat mode. It modifies This function is only applied in chat mode. It modifies
@@ -137,25 +130,17 @@ def bot_prefix_modifier(string):
return string return string
def setup():
global model
model = load_model()
def ui(): def ui():
# Gradio elements # Gradio elements
with gr.Accordion("Silero TTS"): with gr.Accordion("Silero TTS"):
with gr.Row(): with gr.Row():
activate = gr.Checkbox(value=params['activate'], label='Activate TTS') activate = gr.Checkbox(value=params['activate'], label='Activate TTS')
autoplay = gr.Checkbox(value=params['autoplay'], label='Play TTS automatically') autoplay = gr.Checkbox(value=params['autoplay'], label='Play TTS automatically')
show_text = gr.Checkbox(value=params['show_text'], label='Show message text under audio player') show_text = gr.Checkbox(value=params['show_text'], label='Show message text under audio player')
voice = gr.Dropdown(value=params['speaker'], choices=voices_by_gender, label='TTS voice') voice = gr.Dropdown(value=params['speaker'], choices=voices_by_gender, label='TTS voice')
with gr.Row(): with gr.Row():
v_pitch = gr.Dropdown(value=params['voice_pitch'], choices=voice_pitches, label='Voice pitch') v_pitch = gr.Dropdown(value=params['voice_pitch'], choices=voice_pitches, label='Voice pitch')
v_speed = gr.Dropdown(value=params['voice_speed'], choices=voice_speeds, label='Voice speed') v_speed = gr.Dropdown(value=params['voice_speed'], choices=voice_speeds, label='Voice speed')
with gr.Row(): with gr.Row():
convert = gr.Button('Permanently replace audios with the message texts') convert = gr.Button('Permanently replace audios with the message texts')
convert_cancel = gr.Button('Cancel', visible=False) convert_cancel = gr.Button('Cancel', visible=False)
@@ -163,20 +148,20 @@ def ui():
# Convert history with confirmation # Convert history with confirmation
convert_arr = [convert_confirm, convert, convert_cancel] convert_arr = [convert_confirm, convert, convert_cancel]
convert.click(lambda: [gr.update(visible=True), gr.update(visible=False), gr.update(visible=True)], None, convert_arr) convert.click(lambda :[gr.update(visible=True), gr.update(visible=False), gr.update(visible=True)], None, convert_arr)
convert_confirm.click(lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, convert_arr) convert_confirm.click(lambda :[gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, convert_arr)
convert_confirm.click(remove_tts_from_history, [shared.gradio[k] for k in ['name1', 'name2', 'mode']], shared.gradio['display']) convert_confirm.click(remove_tts_from_history, [shared.gradio['name1'], shared.gradio['name2']], shared.gradio['display'])
convert_confirm.click(lambda: chat.save_history(timestamp=False), [], [], show_progress=False) convert_confirm.click(lambda : chat.save_history(timestamp=False), [], [], show_progress=False)
convert_cancel.click(lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, convert_arr) convert_cancel.click(lambda :[gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, convert_arr)
# Toggle message text in history # Toggle message text in history
show_text.change(lambda x: params.update({"show_text": x}), show_text, None) show_text.change(lambda x: params.update({"show_text": x}), show_text, None)
show_text.change(toggle_text_in_history, [shared.gradio[k] for k in ['name1', 'name2', 'mode']], shared.gradio['display']) show_text.change(toggle_text_in_history, [shared.gradio['name1'], shared.gradio['name2']], shared.gradio['display'])
show_text.change(lambda: chat.save_history(timestamp=False), [], [], show_progress=False) show_text.change(lambda : chat.save_history(timestamp=False), [], [], show_progress=False)
# Event functions to update the parameters in the backend # Event functions to update the parameters in the backend
activate.change(lambda x: params.update({"activate": x}), activate, None) activate.change(lambda x: params.update({"activate": x}), activate, None)
autoplay.change(lambda x: params.update({"autoplay": x}), autoplay, None) autoplay.change(lambda x: params.update({"autoplay": x}), autoplay, None)
voice.change(lambda x: params.update({"speaker": x}), voice, None) voice.change(lambda x: params.update({"speaker": x}), voice, None)
v_pitch.change(lambda x: params.update({"voice_pitch": x}), v_pitch, None) v_pitch.change(lambda x: params.update({"voice_pitch": x}), v_pitch, None)
v_speed.change(lambda x: params.update({"voice_speed": x}), v_speed, None) v_speed.change(lambda x: params.update({"voice_speed": x}), v_speed, None)

View File

@@ -1,81 +0,0 @@
import time
from pathlib import Path
import torch
import tts_preprocessor
torch._C._jit_set_profiling_mode(False)
params = {
'activate': True,
'speaker': 'en_49',
'language': 'en',
'model_id': 'v3_en',
'sample_rate': 48000,
'device': 'cpu',
'show_text': True,
'autoplay': True,
'voice_pitch': 'medium',
'voice_speed': 'medium',
}
current_params = params.copy()
voices_by_gender = ['en_99', 'en_45', 'en_18', 'en_117', 'en_49', 'en_51', 'en_68', 'en_0', 'en_26', 'en_56', 'en_74', 'en_5', 'en_38', 'en_53', 'en_21', 'en_37', 'en_107', 'en_10', 'en_82', 'en_16', 'en_41', 'en_12', 'en_67', 'en_61', 'en_14', 'en_11', 'en_39', 'en_52', 'en_24', 'en_97', 'en_28', 'en_72', 'en_94', 'en_36', 'en_4', 'en_43', 'en_88', 'en_25', 'en_65', 'en_6', 'en_44', 'en_75', 'en_91', 'en_60', 'en_109', 'en_85', 'en_101', 'en_108', 'en_50', 'en_96', 'en_64', 'en_92', 'en_76', 'en_33', 'en_116', 'en_48', 'en_98', 'en_86', 'en_62', 'en_54', 'en_95', 'en_55', 'en_111', 'en_3', 'en_83', 'en_8', 'en_47', 'en_59', 'en_1', 'en_2', 'en_7', 'en_9', 'en_13', 'en_15', 'en_17', 'en_19', 'en_20', 'en_22', 'en_23', 'en_27', 'en_29', 'en_30', 'en_31', 'en_32', 'en_34', 'en_35', 'en_40', 'en_42', 'en_46', 'en_57', 'en_58', 'en_63', 'en_66', 'en_69', 'en_70', 'en_71', 'en_73', 'en_77', 'en_78', 'en_79', 'en_80', 'en_81', 'en_84', 'en_87', 'en_89', 'en_90', 'en_93', 'en_100', 'en_102', 'en_103', 'en_104', 'en_105', 'en_106', 'en_110', 'en_112', 'en_113', 'en_114', 'en_115']
voice_pitches = ['x-low', 'low', 'medium', 'high', 'x-high']
voice_speeds = ['x-slow', 'slow', 'medium', 'fast', 'x-fast']
# Used for making text xml compatible, needed for voice pitch and speed control
table = str.maketrans({
"<": "&lt;",
">": "&gt;",
"&": "&amp;",
"'": "&apos;",
'"': "&quot;",
})
def xmlesc(txt):
return txt.translate(table)
def load_model():
model, example_text = torch.hub.load(repo_or_dir='snakers4/silero-models', model='silero_tts', language=params['language'], speaker=params['model_id'])
model.to(params['device'])
return model
model = load_model()
def output_modifier(string):
"""
This function is applied to the model outputs.
"""
global model, current_params
original_string = string
string = tts_preprocessor.preprocess(string)
processed_string = string
if string == '':
string = '*Empty reply, try regenerating*'
else:
output_file = Path(f'extensions/silero_tts/outputs/test_{int(time.time())}.wav')
prosody = '<prosody rate="{}" pitch="{}">'.format(params['voice_speed'], params['voice_pitch'])
silero_input = f'<speak>{prosody}{xmlesc(string)}</prosody></speak>'
model.save_wav(ssml_text=silero_input, speaker=params['speaker'], sample_rate=int(params['sample_rate']), audio_path=str(output_file))
autoplay = 'autoplay' if params['autoplay'] else ''
string = f'<audio src="file/{output_file.as_posix()}" controls {autoplay}></audio>'
if params['show_text']:
string += f'\n\n{original_string}\n\nProcessed:\n{processed_string}'
print(string)
if __name__ == '__main__':
import sys
output_modifier(sys.argv[1])

View File

@@ -1,194 +0,0 @@
import re
from num2words import num2words
punctuation = r'[\s,.?!/)\'\]>]'
alphabet_map = {
"A": " Ei ",
"B": " Bee ",
"C": " See ",
"D": " Dee ",
"E": " Eee ",
"F": " Eff ",
"G": " Jee ",
"H": " Eich ",
"I": " Eye ",
"J": " Jay ",
"K": " Kay ",
"L": " El ",
"M": " Emm ",
"N": " Enn ",
"O": " Ohh ",
"P": " Pee ",
"Q": " Queue ",
"R": " Are ",
"S": " Ess ",
"T": " Tee ",
"U": " You ",
"V": " Vee ",
"W": " Double You ",
"X": " Ex ",
"Y": " Why ",
"Z": " Zed " # Zed is weird, as I (da3dsoul) am American, but most of the voice models sound British, so it matches
}
def preprocess(string):
# the order for some of these matter
# For example, you need to remove the commas in numbers before expanding them
string = remove_surrounded_chars(string)
string = string.replace('"', '')
string = string.replace('\u201D', '').replace('\u201C', '') # right and left quote
string = string.replace('\u201F', '') # italic looking quote
string = string.replace('\n', ' ')
string = convert_num_locale(string)
string = replace_negative(string)
string = replace_roman(string)
string = hyphen_range_to(string)
string = num_to_words(string)
# TODO Try to use a ML predictor to expand abbreviations. It's hard, dependent on context, and whether to actually
# try to say the abbreviation or spell it out as I've done below is not agreed upon
# For now, expand abbreviations to pronunciations
# replace_abbreviations adds a lot of unnecessary whitespace to ensure separation
string = replace_abbreviations(string)
string = replace_lowercase_abbreviations(string)
# cleanup whitespaces
# remove whitespace before punctuation
string = re.sub(rf'\s+({punctuation})', r'\1', string)
string = string.strip()
# compact whitespace
string = ' '.join(string.split())
return string
def remove_surrounded_chars(string):
# this expression matches to 'as few symbols as possible (0 upwards) between any asterisks' OR
# 'as few symbols as possible (0 upwards) between an asterisk and the end of the string'
return re.sub(r'\*[^*]*?(\*|$)', '', string)
def convert_num_locale(text):
# This detects locale and converts it to American without comma separators
pattern = re.compile(r'(?:\s|^)\d{1,3}(?:\.\d{3})+(,\d+)(?:\s|$)')
result = text
while True:
match = pattern.search(result)
if match is None:
break
start = match.start()
end = match.end()
result = result[0:start] + result[start:end].replace('.', '').replace(',', '.') + result[end:len(result)]
# removes comma separators from existing American numbers
pattern = re.compile(r'(\d),(\d)')
result = pattern.sub(r'\1\2', result)
return result
def replace_negative(string):
# handles situations like -5. -5 would become negative 5, which would then be expanded to negative five
return re.sub(rf'(\s)(-)(\d+)({punctuation})', r'\1negative \3\4', string)
def replace_roman(string):
# find a string of roman numerals.
# Only 2 or more, to avoid capturing I and single character abbreviations, like names
pattern = re.compile(rf'\s[IVXLCDM]{{2,}}{punctuation}')
result = string
while True:
match = pattern.search(result)
if match is None:
break
start = match.start()
end = match.end()
result = result[0:start + 1] + str(roman_to_int(result[start + 1:end - 1])) + result[end - 1:len(result)]
return result
def roman_to_int(s):
rom_val = {'I': 1, 'V': 5, 'X': 10, 'L': 50, 'C': 100, 'D': 500, 'M': 1000}
int_val = 0
for i in range(len(s)):
if i > 0 and rom_val[s[i]] > rom_val[s[i - 1]]:
int_val += rom_val[s[i]] - 2 * rom_val[s[i - 1]]
else:
int_val += rom_val[s[i]]
return int_val
def hyphen_range_to(text):
pattern = re.compile(r'(\d+)[-](\d+)')
result = pattern.sub(lambda x: x.group(1) + ' to ' + x.group(2), text)
return result
def num_to_words(text):
# 1000 or 10.23
pattern = re.compile(r'\d+\.\d+|\d+')
result = pattern.sub(lambda x: num2words(float(x.group())), text)
return result
def replace_abbreviations(string):
# abbreviations 1 to 4 characters long. It will get things like A and I, but those are pronounced with their letter
pattern = re.compile(rf'(^|[\s(.\'\[<])([A-Z]{{1,4}})({punctuation}|$)')
result = string
while True:
match = pattern.search(result)
if match is None:
break
start = match.start()
end = match.end()
result = result[0:start] + replace_abbreviation(result[start:end]) + result[end:len(result)]
return result
def replace_lowercase_abbreviations(string):
# abbreviations 1 to 4 characters long, separated by dots i.e. e.g.
pattern = re.compile(rf'(^|[\s(.\'\[<])(([a-z]\.){{1,4}})({punctuation}|$)')
result = string
while True:
match = pattern.search(result)
if match is None:
break
start = match.start()
end = match.end()
result = result[0:start] + replace_abbreviation(result[start:end].upper()) + result[end:len(result)]
return result
def replace_abbreviation(string):
result = ""
for char in string:
result += match_mapping(char)
return result
def match_mapping(char):
for mapping in alphabet_map.keys():
if char == mapping:
return alphabet_map[char]
return char
def __main__(args):
print(preprocess(args[1]))
if __name__ == "__main__":
import sys
__main__(sys.argv)

View File

@@ -1,6 +1,5 @@
import gradio as gr import gradio as gr
import speech_recognition as sr import speech_recognition as sr
from modules import shared
input_hijack = { input_hijack = {
'state': False, 'state': False,
@@ -8,7 +7,7 @@ input_hijack = {
} }
def do_stt(audio): def do_stt(audio, text_state=""):
transcription = "" transcription = ""
r = sr.Recognizer() r = sr.Recognizer()
@@ -22,23 +21,34 @@ def do_stt(audio):
except sr.RequestError as e: except sr.RequestError as e:
print("Could not request results from Whisper", e) print("Could not request results from Whisper", e)
return transcription input_hijack.update({"state": True, "value": [transcription, transcription]})
text_state += transcription + " "
return text_state, text_state
def auto_transcribe(audio, auto_submit): def update_hijack(val):
input_hijack.update({"state": True, "value": [val, val]})
return val
def auto_transcribe(audio, audio_auto, text_state=""):
if audio is None: if audio is None:
return "", "" return "", ""
if audio_auto:
transcription = do_stt(audio) return do_stt(audio, text_state)
if auto_submit: return "", ""
input_hijack.update({"state": True, "value": [transcription, transcription]})
return transcription, None
def ui(): def ui():
tr_state = gr.State(value="")
output_transcription = gr.Textbox(label="STT-Input",
placeholder="Speech Preview. Click \"Generate\" to send",
interactive=True)
output_transcription.change(fn=update_hijack, inputs=[output_transcription], outputs=[tr_state])
audio_auto = gr.Checkbox(label="Auto-Transcribe", value=True)
with gr.Row(): with gr.Row():
audio = gr.Audio(source="microphone") audio = gr.Audio(source="microphone")
auto_submit = gr.Checkbox(label='Submit the transcribed audio automatically', value=True) audio.change(fn=auto_transcribe, inputs=[audio, audio_auto, tr_state], outputs=[output_transcription, tr_state])
audio.change(fn=auto_transcribe, inputs=[audio, auto_submit], outputs=[shared.gradio['textbox'], audio]) transcribe_button = gr.Button(value="Transcribe")
audio.change(None, auto_submit, None, _js="(check) => {if (check) { document.getElementById('Generate').click() }}") transcribe_button.click(do_stt, inputs=[audio, tr_state], outputs=[output_transcription, tr_state])

View File

@@ -100,10 +100,10 @@ def load_quantized(model_name):
found_safetensors = list(path_to_model.glob("*.safetensors")) found_safetensors = list(path_to_model.glob("*.safetensors"))
pt_path = None pt_path = None
if len(found_pts) > 0: if len(found_pts) == 1:
pt_path = found_pts[-1] pt_path = found_pts[0]
elif len(found_safetensors) > 0: elif len(found_safetensors) == 1:
pt_path = found_safetensors[-1] pt_path = found_safetensors[0]
else: else:
if path_to_model.name.lower().startswith('llama-7b'): if path_to_model.name.lower().startswith('llama-7b'):
pt_model = f'llama-7b-{shared.args.wbits}bit' pt_model = f'llama-7b-{shared.args.wbits}bit'
@@ -119,14 +119,13 @@ def load_quantized(model_name):
# Try to find the .safetensors or .pt both in the model dir and in the subfolder # Try to find the .safetensors or .pt both in the model dir and in the subfolder
for path in [Path(p + ext) for ext in ['.safetensors', '.pt'] for p in [f"{shared.args.model_dir}/{pt_model}", f"{path_to_model}/{pt_model}"]]: for path in [Path(p + ext) for ext in ['.safetensors', '.pt'] for p in [f"{shared.args.model_dir}/{pt_model}", f"{path_to_model}/{pt_model}"]]:
if path.exists(): if path.exists():
print(f"Found {path}")
pt_path = path pt_path = path
break break
if not pt_path: if not pt_path:
print("Could not find the quantized model in .pt or .safetensors format, exiting...") print("Could not find the quantized model in .pt or .safetensors format, exiting...")
exit() exit()
else:
print(f"Found the following quantized model: {pt_path}")
# qwopqwop200's offload # qwopqwop200's offload
if model_type == 'llama' and shared.args.pre_layer: if model_type == 'llama' and shared.args.pre_layer:

View File

@@ -4,7 +4,14 @@ import torch
from peft import PeftModel from peft import PeftModel
import modules.shared as shared import modules.shared as shared
from modules.models import reload_model from modules.models import load_model
from modules.text_generation import clear_torch_cache
def reload_model():
shared.model = shared.tokenizer = None
clear_torch_cache()
shared.model, shared.tokenizer = load_model(shared.model_name)
def add_lora_to_model(lora_name): def add_lora_to_model(lora_name):

View File

@@ -12,59 +12,53 @@ from PIL import Image
import modules.extensions as extensions_module import modules.extensions as extensions_module
import modules.shared as shared import modules.shared as shared
from modules.extensions import apply_extensions from modules.extensions import apply_extensions
from modules.html_generator import (chat_html_wrapper, fix_newlines, from modules.html_generator import (fix_newlines, chat_html_wrapper,
make_thumbnail) make_thumbnail)
from modules.text_generation import (encode, generate_reply, from modules.text_generation import (encode, generate_reply,
get_max_prompt_length) get_max_prompt_length)
def generate_chat_prompt(user_input, state, **kwargs): def generate_chat_prompt(user_input, max_new_tokens, name1, name2, context, chat_prompt_size, **kwargs):
is_instruct = kwargs['is_instruct'] if 'is_instruct' in kwargs else False
end_of_turn = kwargs['end_of_turn'] if 'end_of_turn' in kwargs else ''
impersonate = kwargs['impersonate'] if 'impersonate' in kwargs else False impersonate = kwargs['impersonate'] if 'impersonate' in kwargs else False
_continue = kwargs['_continue'] if '_continue' in kwargs else False
also_return_rows = kwargs['also_return_rows'] if 'also_return_rows' in kwargs else False also_return_rows = kwargs['also_return_rows'] if 'also_return_rows' in kwargs else False
is_instruct = state['mode'] == 'instruct' rows = [f"{context.strip()}\n"]
rows = [f"{state['context'].strip()}\n"]
# Finding the maximum prompt size # Finding the maximum prompt size
chat_prompt_size = state['chat_prompt_size']
if shared.soft_prompt: if shared.soft_prompt:
chat_prompt_size -= shared.soft_prompt_tensor.shape[1] chat_prompt_size -= shared.soft_prompt_tensor.shape[1]
max_length = min(get_max_prompt_length(state), chat_prompt_size) max_length = min(get_max_prompt_length(max_new_tokens), chat_prompt_size)
if is_instruct: if is_instruct:
prefix1 = f"{state['name1']}\n" prefix1 = f"{name1}\n"
prefix2 = f"{state['name2']}\n" prefix2 = f"{name2}\n"
else: else:
prefix1 = f"{state['name1']}: " prefix1 = f"{name1}: "
prefix2 = f"{state['name2']}: " prefix2 = f"{name2}: "
i = len(shared.history['internal']) - 1 i = len(shared.history['internal']) - 1
while i >= 0 and len(encode(''.join(rows))[0]) < max_length: while i >= 0 and len(encode(''.join(rows), max_new_tokens)[0]) < max_length:
if _continue and i == len(shared.history['internal']) - 1: rows.insert(1, f"{prefix2}{shared.history['internal'][i][1].strip()}{end_of_turn}\n")
rows.insert(1, f"{prefix2}{shared.history['internal'][i][1]}")
else:
rows.insert(1, f"{prefix2}{shared.history['internal'][i][1].strip()}{state['end_of_turn']}\n")
string = shared.history['internal'][i][0] string = shared.history['internal'][i][0]
if string not in ['', '<|BEGIN-VISIBLE-CHAT|>']: if string not in ['', '<|BEGIN-VISIBLE-CHAT|>']:
rows.insert(1, f"{prefix1}{string.strip()}{state['end_of_turn']}\n") rows.insert(1, f"{prefix1}{string.strip()}{end_of_turn}\n")
i -= 1 i -= 1
if impersonate: if impersonate:
rows.append(f"{prefix1.strip() if not is_instruct else prefix1}") rows.append(f"{prefix1.strip() if not is_instruct else prefix1}")
limit = 2 limit = 2
elif _continue:
limit = 3
else: else:
# Adding the user message # Adding the user message
user_input = fix_newlines(user_input) user_input = fix_newlines(user_input)
if len(user_input) > 0: if len(user_input) > 0:
rows.append(f"{prefix1}{user_input}{state['end_of_turn']}\n") rows.append(f"{prefix1}{user_input}{end_of_turn}\n")
# Adding the Character prefix # Adding the Character prefix
rows.append(apply_extensions(f"{prefix2.strip() if not is_instruct else prefix2}", "bot_prefix")) rows.append(apply_extensions(f"{prefix2.strip() if not is_instruct else prefix2}", "bot_prefix"))
limit = 3 limit = 3
while len(rows) > limit and len(encode(''.join(rows))[0]) >= max_length: while len(rows) > limit and len(encode(''.join(rows), max_new_tokens)[0]) >= max_length:
rows.pop(1) rows.pop(1)
prompt = ''.join(rows) prompt = ''.join(rows)
@@ -74,26 +68,16 @@ def generate_chat_prompt(user_input, state, **kwargs):
return prompt return prompt
def get_stopping_strings(state): def extract_message_from_reply(reply, name1, name2, stop_at_newline):
if state['mode'] == 'instruct':
stopping_strings = [f"\n{state['name1']}", f"\n{state['name2']}"]
else:
stopping_strings = [f"\n{state['name1']}:", f"\n{state['name2']}:"]
stopping_strings += state['custom_stopping_strings']
return stopping_strings
def extract_message_from_reply(reply, state):
next_character_found = False next_character_found = False
stopping_strings = get_stopping_strings(state)
if state['stop_at_newline']: if stop_at_newline:
lines = reply.split('\n') lines = reply.split('\n')
reply = lines[0].strip() reply = lines[0].strip()
if len(lines) > 1: if len(lines) > 1:
next_character_found = True next_character_found = True
else: else:
for string in stopping_strings: for string in [f"\n{name1}:", f"\n{name2}:"]:
idx = reply.find(string) idx = reply.find(string)
if idx != -1: if idx != -1:
reply = reply[:idx] reply = reply[:idx]
@@ -102,7 +86,7 @@ def extract_message_from_reply(reply, state):
# If something like "\nYo" is generated just before "\nYou:" # If something like "\nYo" is generated just before "\nYou:"
# is completed, trim it # is completed, trim it
if not next_character_found: if not next_character_found:
for string in stopping_strings: for string in [f"\n{name1}:", f"\n{name2}:"]:
for j in range(len(string) - 1, 0, -1): for j in range(len(string) - 1, 0, -1):
if reply[-j:] == string[:j]: if reply[-j:] == string[:j]:
reply = reply[:-j] reply = reply[:-j]
@@ -115,17 +99,20 @@ def extract_message_from_reply(reply, state):
return reply, next_character_found return reply, next_character_found
def chatbot_wrapper(text, state, regenerate=False, _continue=False): def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn, regenerate=False):
if mode == 'instruct':
stopping_strings = [f"\n{name1}", f"\n{name2}"]
else:
stopping_strings = [f"\n{name1}:", f"\n{name2}:"]
# Defining some variables eos_token = '\n' if generate_state['stop_at_newline'] else None
cumulative_reply = '' name1_original = name1
last_reply = [shared.history['internal'][-1][1], shared.history['visible'][-1][1]] if _continue else None if 'pygmalion' in shared.model_name.lower():
just_started = True name1 = "You"
visible_text = custom_generate_chat_prompt = None
eos_token = '\n' if state['stop_at_newline'] else None
stopping_strings = get_stopping_strings(state)
# Check if any extension wants to hijack this function call # Check if any extension wants to hijack this function call
visible_text = None
custom_generate_chat_prompt = None
for extension, _ in extensions_module.iterator(): for extension, _ in extensions_module.iterator():
if hasattr(extension, 'input_hijack') and extension.input_hijack['state']: if hasattr(extension, 'input_hijack') and extension.input_hijack['state']:
extension.input_hijack['state'] = False extension.input_hijack['state'] = False
@@ -135,29 +122,29 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False):
if visible_text is None: if visible_text is None:
visible_text = text visible_text = text
if not _continue: text = apply_extensions(text, "input")
text = apply_extensions(text, "input")
# Generating the prompt kwargs = {'end_of_turn': end_of_turn, 'is_instruct': mode == 'instruct'}
kwargs = {'_continue': _continue}
if custom_generate_chat_prompt is None: if custom_generate_chat_prompt is None:
prompt = generate_chat_prompt(text, state, **kwargs) prompt = generate_chat_prompt(text, generate_state['max_new_tokens'], name1, name2, context, generate_state['chat_prompt_size'], **kwargs)
else: else:
prompt = custom_generate_chat_prompt(text, state, **kwargs) prompt = custom_generate_chat_prompt(text, generate_state['max_new_tokens'], name1, name2, context, generate_state['chat_prompt_size'], **kwargs)
# Yield *Is typing...* # Yield *Is typing...*
if not any((regenerate, _continue)): if not regenerate:
yield shared.history['visible'] + [[visible_text, shared.processing_message]] yield shared.history['visible'] + [[visible_text, shared.processing_message]]
# Generate # Generate
for i in range(state['chat_generation_attempts']): cumulative_reply = ''
just_started = True
for i in range(generate_state['chat_generation_attempts']):
reply = None reply = None
for reply in generate_reply(f"{prompt}{' ' if len(cumulative_reply) > 0 else ''}{cumulative_reply}", state, eos_token=eos_token, stopping_strings=stopping_strings): for reply in generate_reply(f"{prompt}{' ' if len(cumulative_reply) > 0 else ''}{cumulative_reply}", generate_state, eos_token=eos_token, stopping_strings=stopping_strings):
reply = cumulative_reply + reply reply = cumulative_reply + reply
# Extracting the reply # Extracting the reply
reply, next_character_found = extract_message_from_reply(reply, state) reply, next_character_found = extract_message_from_reply(reply, name1, name2, generate_state['stop_at_newline'])
visible_reply = re.sub("(<USER>|<user>|{{user}})", state['name1'], reply) visible_reply = re.sub("(<USER>|<user>|{{user}})", name1_original, reply)
visible_reply = apply_extensions(visible_reply, "output") visible_reply = apply_extensions(visible_reply, "output")
# We need this global variable to handle the Stop event, # We need this global variable to handle the Stop event,
@@ -166,17 +153,11 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False):
return shared.history['visible'] return shared.history['visible']
if just_started: if just_started:
just_started = False just_started = False
if not _continue: shared.history['internal'].append(['', ''])
shared.history['internal'].append(['', '']) shared.history['visible'].append(['', ''])
shared.history['visible'].append(['', ''])
if _continue: shared.history['internal'][-1] = [text, reply]
sep = list(map(lambda x: ' ' if len(x) > 0 and x[-1] != ' ' else '', last_reply)) shared.history['visible'][-1] = [visible_text, visible_reply]
shared.history['internal'][-1] = [text, f'{last_reply[0]}{sep[0]}{reply}']
shared.history['visible'][-1] = [visible_text, f'{last_reply[1]}{sep[1]}{visible_reply}']
else:
shared.history['internal'][-1] = [text, reply]
shared.history['visible'][-1] = [visible_text, visible_reply]
if not shared.args.no_stream: if not shared.args.no_stream:
yield shared.history['visible'] yield shared.history['visible']
if next_character_found: if next_character_found:
@@ -188,22 +169,27 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False):
yield shared.history['visible'] yield shared.history['visible']
def impersonate_wrapper(text, state): def impersonate_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn):
if mode == 'instruct':
stopping_strings = [f"\n{name1}", f"\n{name2}"]
else:
stopping_strings = [f"\n{name1}:", f"\n{name2}:"]
# Defining some variables eos_token = '\n' if generate_state['stop_at_newline'] else None
cumulative_reply = '' if 'pygmalion' in shared.model_name.lower():
eos_token = '\n' if state['stop_at_newline'] else None name1 = "You"
prompt = generate_chat_prompt(text, state, impersonate=True)
stopping_strings = get_stopping_strings(state) prompt = generate_chat_prompt(text, generate_state['max_new_tokens'], name1, name2, context, generate_state['chat_prompt_size'], impersonate=True, end_of_turn=end_of_turn)
# Yield *Is typing...* # Yield *Is typing...*
yield shared.processing_message yield shared.processing_message
for i in range(state['chat_generation_attempts']): cumulative_reply = ''
for i in range(generate_state['chat_generation_attempts']):
reply = None reply = None
for reply in generate_reply(f"{prompt}{' ' if len(cumulative_reply) > 0 else ''}{cumulative_reply}", state, eos_token=eos_token, stopping_strings=stopping_strings): for reply in generate_reply(f"{prompt}{' ' if len(cumulative_reply) > 0 else ''}{cumulative_reply}", generate_state, eos_token=eos_token, stopping_strings=stopping_strings):
reply = cumulative_reply + reply reply = cumulative_reply + reply
reply, next_character_found = extract_message_from_reply(reply, state) reply, next_character_found = extract_message_from_reply(reply, name1, name2, generate_state['stop_at_newline'])
yield reply yield reply
if next_character_found: if next_character_found:
break break
@@ -214,32 +200,22 @@ def impersonate_wrapper(text, state):
yield reply yield reply
def cai_chatbot_wrapper(text, state): def cai_chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn):
for history in chatbot_wrapper(text, state): for history in chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn):
yield chat_html_wrapper(history, state['name1'], state['name2'], state['mode']) yield chat_html_wrapper(history, name1, name2, mode)
def regenerate_wrapper(text, state): def regenerate_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn):
if (len(shared.history['visible']) == 1 and not shared.history['visible'][0][0]) or len(shared.history['internal']) == 0: if (shared.character != 'None' and len(shared.history['visible']) == 1) or len(shared.history['internal']) == 0:
yield chat_html_wrapper(shared.history['visible'], state['name1'], state['name2'], state['mode']) yield chat_html_wrapper(shared.history['visible'], name1, name2, mode)
else: else:
last_visible = shared.history['visible'].pop() last_visible = shared.history['visible'].pop()
last_internal = shared.history['internal'].pop() last_internal = shared.history['internal'].pop()
# Yield '*Is typing...*' # Yield '*Is typing...*'
yield chat_html_wrapper(shared.history['visible'] + [[last_visible[0], shared.processing_message]], state['name1'], state['name2'], state['mode']) yield chat_html_wrapper(shared.history['visible'] + [[last_visible[0], shared.processing_message]], name1, name2, mode)
for history in chatbot_wrapper(last_internal[0], state, regenerate=True): for history in chatbot_wrapper(last_internal[0], generate_state, name1, name2, context, mode, end_of_turn, regenerate=True):
shared.history['visible'][-1] = [last_visible[0], history[-1][1]] shared.history['visible'][-1] = [last_visible[0], history[-1][1]]
yield chat_html_wrapper(shared.history['visible'], state['name1'], state['name2'], state['mode']) yield chat_html_wrapper(shared.history['visible'], name1, name2, mode)
def continue_wrapper(text, state):
if (len(shared.history['visible']) == 1 and not shared.history['visible'][0][0]) or len(shared.history['internal']) == 0:
yield chat_html_wrapper(shared.history['visible'], state['name1'], state['name2'], state['mode'])
else:
# Yield ' ...'
yield chat_html_wrapper(shared.history['visible'][:-1] + [[shared.history['visible'][-1][0], shared.history['visible'][-1][1] + ' ...']], state['name1'], state['name2'], state['mode'])
for history in chatbot_wrapper(shared.history['internal'][-1][0], state, _continue=True):
yield chat_html_wrapper(shared.history['visible'], state['name1'], state['name2'], state['mode'])
def remove_last_message(name1, name2, mode): def remove_last_message(name1, name2, mode):
@@ -267,21 +243,6 @@ def replace_last_reply(text, name1, name2, mode):
return chat_html_wrapper(shared.history['visible'], name1, name2, mode) return chat_html_wrapper(shared.history['visible'], name1, name2, mode)
def send_dummy_message(text, name1, name2, mode):
shared.history['visible'].append([text, ''])
shared.history['internal'].append([apply_extensions(text, "input"), ''])
return chat_html_wrapper(shared.history['visible'], name1, name2, mode)
def send_dummy_reply(text, name1, name2, mode):
if len(shared.history['visible']) > 0 and not shared.history['visible'][-1][1] == '':
shared.history['visible'].append(['', ''])
shared.history['internal'].append(['', ''])
shared.history['visible'][-1][1] = text
shared.history['internal'][-1][1] = apply_extensions(text, "input")
return chat_html_wrapper(shared.history['visible'], name1, name2, mode)
def clear_html(): def clear_html():
return chat_html_wrapper([], "", "") return chat_html_wrapper([], "", "")
@@ -294,9 +255,6 @@ def clear_chat_log(name1, name2, greeting, mode):
shared.history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', greeting]] shared.history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', greeting]]
shared.history['visible'] += [['', apply_extensions(greeting, "output")]] shared.history['visible'] += [['', apply_extensions(greeting, "output")]]
# Save cleared logs
save_history(mode)
return chat_html_wrapper(shared.history['visible'], name1, name2, mode) return chat_html_wrapper(shared.history['visible'], name1, name2, mode)
@@ -306,7 +264,7 @@ def redraw_html(name1, name2, mode):
def tokenize_dialogue(dialogue, name1, name2, mode): def tokenize_dialogue(dialogue, name1, name2, mode):
history = [] history = []
messages = []
dialogue = re.sub('<START>', '', dialogue) dialogue = re.sub('<START>', '', dialogue)
dialogue = re.sub('<start>', '', dialogue) dialogue = re.sub('<start>', '', dialogue)
dialogue = re.sub('(\n|^)[Aa]non:', '\\1You:', dialogue) dialogue = re.sub('(\n|^)[Aa]non:', '\\1You:', dialogue)
@@ -315,6 +273,7 @@ def tokenize_dialogue(dialogue, name1, name2, mode):
if len(idx) == 0: if len(idx) == 0:
return history return history
messages = []
for i in range(len(idx) - 1): for i in range(len(idx) - 1):
messages.append(dialogue[idx[i]:idx[i + 1]].strip()) messages.append(dialogue[idx[i]:idx[i + 1]].strip())
messages.append(dialogue[idx[-1]:].strip()) messages.append(dialogue[idx[-1]:].strip())
@@ -341,23 +300,15 @@ def tokenize_dialogue(dialogue, name1, name2, mode):
return history return history
def save_history(mode, timestamp=False): def save_history(timestamp=True):
# Instruct mode histories should not be saved as if if timestamp:
# Alpaca or Vicuna were characters fname = f"{shared.character}_{datetime.now().strftime('%Y%m%d-%H%M%S')}.json"
if mode == 'instruct':
if not timestamp:
return
fname = f"Instruct_{datetime.now().strftime('%Y%m%d-%H%M%S')}.json"
else: else:
if timestamp: fname = f"{shared.character}_persistent.json"
fname = f"{shared.character}_{datetime.now().strftime('%Y%m%d-%H%M%S')}.json"
else:
fname = f"{shared.character}_persistent.json"
if not Path('logs').exists(): if not Path('logs').exists():
Path('logs').mkdir() Path('logs').mkdir()
with open(Path(f'logs/{fname}'), 'w', encoding='utf-8') as f: with open(Path(f'logs/{fname}'), 'w', encoding='utf-8') as f:
f.write(json.dumps({'data': shared.history['internal'], 'data_visible': shared.history['visible']}, indent=2)) f.write(json.dumps({'data': shared.history['internal'], 'data_visible': shared.history['visible']}, indent=2))
return Path(f'logs/{fname}') return Path(f'logs/{fname}')
@@ -371,6 +322,16 @@ def load_history(file, name1, name2):
shared.history['visible'] = j['data_visible'] shared.history['visible'] = j['data_visible']
else: else:
shared.history['visible'] = copy.deepcopy(shared.history['internal']) shared.history['visible'] = copy.deepcopy(shared.history['internal'])
# Compatibility with Pygmalion AI's official web UI
elif 'chat' in j:
shared.history['internal'] = [':'.join(x.split(':')[1:]).strip() for x in j['chat']]
if len(j['chat']) > 0 and j['chat'][0].startswith(f'{name2}:'):
shared.history['internal'] = [['<|BEGIN-VISIBLE-CHAT|>', shared.history['internal'][0]]] + [[shared.history['internal'][i], shared.history['internal'][i + 1]] for i in range(1, len(shared.history['internal']) - 1, 2)]
shared.history['visible'] = copy.deepcopy(shared.history['internal'])
shared.history['visible'][0][0] = ''
else:
shared.history['internal'] = [[shared.history['internal'][i], shared.history['internal'][i + 1]] for i in range(0, len(shared.history['internal']) - 1, 2)]
shared.history['visible'] = copy.deepcopy(shared.history['internal'])
except: except:
shared.history['internal'] = tokenize_dialogue(file, name1, name2) shared.history['internal'] = tokenize_dialogue(file, name1, name2)
shared.history['visible'] = copy.deepcopy(shared.history['internal']) shared.history['visible'] = copy.deepcopy(shared.history['internal'])
@@ -406,6 +367,8 @@ def generate_pfp_cache(character):
def load_character(character, name1, name2, mode): def load_character(character, name1, name2, mode):
shared.character = character shared.character = character
shared.history['internal'] = []
shared.history['visible'] = []
context = greeting = end_of_turn = "" context = greeting = end_of_turn = ""
greeting_field = 'greeting' greeting_field = 'greeting'
picture = None picture = None
@@ -450,22 +413,13 @@ def load_character(character, name1, name2, mode):
greeting = shared.settings['greeting'] greeting = shared.settings['greeting']
end_of_turn = shared.settings['end_of_turn'] end_of_turn = shared.settings['end_of_turn']
if mode != 'instruct': if Path(f'logs/{shared.character}_persistent.json').exists():
shared.history['internal'] = [] load_history(open(Path(f'logs/{shared.character}_persistent.json'), 'rb').read(), name1, name2)
shared.history['visible'] = [] elif greeting != "":
shared.history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', greeting]]
shared.history['visible'] += [['', apply_extensions(greeting, "output")]]
if Path(f'logs/{shared.character}_persistent.json').exists(): return name1, name2, picture, greeting, context, end_of_turn, chat_html_wrapper(shared.history['visible'], name1, name2, mode, reset_cache=True)
load_history(open(Path(f'logs/{shared.character}_persistent.json'), 'rb').read(), name1, name2)
else:
# Insert greeting if it exists
if greeting != "":
shared.history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', greeting]]
shared.history['visible'] += [['', apply_extensions(greeting, "output")]]
# Create .json log files since they don't already exist
save_history(mode)
return name1, name2, picture, greeting, context, end_of_turn, chat_html_wrapper(shared.history['visible'], name1, name2, mode)
def load_default_history(name1, name2): def load_default_history(name1, name2):
@@ -513,4 +467,4 @@ def upload_your_profile_picture(img, name1, name2, mode):
img.save(Path('cache/pfp_me.png')) img.save(Path('cache/pfp_me.png'))
print('Profile picture saved to "cache/pfp_me.png"') print('Profile picture saved to "cache/pfp_me.png"')
return chat_html_wrapper(shared.history['visible'], name1, name2, mode, reset_cache=True) return chat_html_wrapper(shared.history['visible'], name1, name2, mode, reset_cache=True)

View File

@@ -11,31 +11,29 @@ setup_called = set()
def load_extensions(): def load_extensions():
global state, setup_called global state
for i, name in enumerate(shared.args.extensions): for i, name in enumerate(shared.args.extensions):
if name in available_extensions: if name in available_extensions:
print(f'Loading the extension "{name}"... ', end='') print(f'Loading the extension "{name}"... ', end='')
try: try:
exec(f"import extensions.{name}.script") exec(f"import extensions.{name}.script")
extension = eval(f"extensions.{name}.script")
if extension not in setup_called and hasattr(extension, "setup"):
setup_called.add(extension)
extension.setup()
state[name] = [True, i] state[name] = [True, i]
print('Ok.') print('Ok.')
except: except:
print('Fail.') print('Fail.')
traceback.print_exc() traceback.print_exc()
# This iterator returns the extensions in the order specified in the command-line # This iterator returns the extensions in the order specified in the command-line
def iterator(): def iterator():
for name in sorted(state, key=lambda x: state[x][1]): for name in sorted(state, key=lambda x: state[x][1]):
if state[name][0]: if state[name][0] == True:
yield eval(f"extensions.{name}.script"), name yield eval(f"extensions.{name}.script"), name
# Extension functions that map string -> string # Extension functions that map string -> string
def apply_extensions(text, typ): def apply_extensions(text, typ):
for extension, _ in iterator(): for extension, _ in iterator():
if typ == "input" and hasattr(extension, "input_modifier"): if typ == "input" and hasattr(extension, "input_modifier"):
@@ -59,9 +57,14 @@ def create_extensions_block():
extension.params[param] = shared.settings[_id] extension.params[param] = shared.settings[_id]
should_display_ui = False should_display_ui = False
# Running setup function
for extension, name in iterator(): for extension, name in iterator():
if hasattr(extension, "ui"): if hasattr(extension, "ui"):
should_display_ui = True should_display_ui = True
if extension not in setup_called and hasattr(extension, "setup"):
setup_called.add(extension)
extension.setup()
# Creating the extension ui elements # Creating the extension ui elements
if should_display_ui: if should_display_ui:

View File

@@ -164,9 +164,10 @@ def generate_instruct_html(history):
def generate_cai_chat_html(history, name1, name2, reset_cache=False): def generate_cai_chat_html(history, name1, name2, reset_cache=False):
output = f'<style>{cai_css}</style><div class="chat" id="chat">' output = f'<style>{cai_css}</style><div class="chat" id="chat">'
# We use ?name2 and ?time.time() to force the browser to reset caches # The time.time() is to prevent the brower from caching the image
img_bot = f'<img src="file/cache/pfp_character.png?{name2}">' if Path("cache/pfp_character.png").exists() else '' suffix = f"?{time.time()}" if reset_cache else f"?{name2}"
img_me = f'<img src="file/cache/pfp_me.png?{time.time() if reset_cache else ""}">' if Path("cache/pfp_me.png").exists() else '' img_bot = f'<img src="file/cache/pfp_character.png{suffix}">' if Path("cache/pfp_character.png").exists() else ''
img_me = f'<img src="file/cache/pfp_me.png{suffix}">' if Path("cache/pfp_me.png").exists() else ''
for i, _row in enumerate(history[::-1]): for i, _row in enumerate(history[::-1]):
row = [convert_to_markdown(entry) for entry in _row] row = [convert_to_markdown(entry) for entry in _row]

View File

@@ -1,176 +0,0 @@
import math
import sys
import torch
import torch.nn as nn
import transformers.models.llama.modeling_llama
from typing import Optional
from typing import Tuple
import modules.shared as shared
if shared.args.xformers:
try:
import xformers.ops
except Exception:
print("🔴 xformers not found! Please install it before trying to use it.", file=sys.stderr)
def hijack_llama_attention():
if shared.args.xformers:
transformers.models.llama.modeling_llama.LlamaAttention.forward = xformers_forward
print("Replaced attention with xformers_attention")
elif shared.args.sdp_attention:
transformers.models.llama.modeling_llama.LlamaAttention.forward = sdp_attention_forward
print("Replaced attention with sdp_attention")
def xformers_forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = transformers.models.llama.modeling_llama.apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
# [bsz, nh, t, hd]
if past_key_value is not None:
# reuse k, v, self_attention
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
past_key_value = (key_states, value_states) if use_cache else None
#We only apply xformers optimizations if we don't need to output the whole attention matrix
if not output_attentions:
dtype = query_states.dtype
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
#This is a nasty hack. We know attention_mask in transformers is either LowerTriangular or all Zeros.
#We therefore check if one element in the upper triangular portion is zero. If it is, then the mask is all zeros.
if attention_mask is None or attention_mask[0, 0, 0, 1] == 0:
# input and output should be of form (bsz, q_len, num_heads, head_dim)
attn_output = xformers.ops.memory_efficient_attention(query_states, key_states, value_states, attn_bias=None)
else:
# input and output should be of form (bsz, q_len, num_heads, head_dim)
attn_output = xformers.ops.memory_efficient_attention(query_states, key_states, value_states, attn_bias=xformers.ops.LowerTriangularMask())
attn_weights = None
else:
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
raise ValueError(
f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is"
f" {attn_weights.size()}"
)
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
)
attn_weights = attn_weights + attention_mask
attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_output = torch.matmul(attn_weights, value_states)
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
attn_output = self.o_proj(attn_output)
return attn_output, attn_weights, past_key_value
def sdp_attention_forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = transformers.models.llama.modeling_llama.apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
# [bsz, nh, t, hd]
if past_key_value is not None:
# reuse k, v, self_attention
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
past_key_value = (key_states, value_states) if use_cache else None
#We only apply sdp attention if we don't need to output the whole attention matrix
if not output_attentions:
attn_output = torch.nn.functional.scaled_dot_product_attention(query_states, key_states, value_states, attn_mask=attention_mask, is_causal=False)
attn_weights = None
else:
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
raise ValueError(
f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is"
f" {attn_weights.size()}"
)
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
)
attn_weights = attn_weights + attention_mask
attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_output = torch.matmul(attn_weights, value_states)
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
attn_output = self.o_proj(attn_output)
return attn_output, attn_weights, past_key_value

View File

@@ -1,4 +1,3 @@
import gc
import json import json
import os import os
import re import re
@@ -14,14 +13,14 @@ from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer,
BitsAndBytesConfig, LlamaTokenizer) BitsAndBytesConfig, LlamaTokenizer)
import modules.shared as shared import modules.shared as shared
from modules import llama_attn_hijack
transformers.logging.set_verbosity_error() transformers.logging.set_verbosity_error()
local_rank = None
if shared.args.flexgen: if shared.args.flexgen:
from flexgen.flex_opt import CompressionConfig, ExecutionEnv, OptLM, Policy from flexgen.flex_opt import CompressionConfig, ExecutionEnv, OptLM, Policy
local_rank = None
if shared.args.deepspeed: if shared.args.deepspeed:
import deepspeed import deepspeed
from transformers.deepspeed import (HfDeepSpeedConfig, from transformers.deepspeed import (HfDeepSpeedConfig,
@@ -170,46 +169,19 @@ def load_model(model_name):
model = AutoModelForCausalLM.from_pretrained(checkpoint, **params) model = AutoModelForCausalLM.from_pretrained(checkpoint, **params)
# Hijack attention with xformers
if any((shared.args.xformers, shared.args.sdp_attention)):
llama_attn_hijack.hijack_llama_attention()
# Loading the tokenizer # Loading the tokenizer
if any((k in shared.model_name.lower() for k in ['gpt4chan', 'gpt-4chan'])) and Path(f"{shared.args.model_dir}/gpt-j-6B/").exists(): if any((k in shared.model_name.lower() for k in ['gpt4chan', 'gpt-4chan'])) and Path(f"{shared.args.model_dir}/gpt-j-6B/").exists():
tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/gpt-j-6B/")) tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/gpt-j-6B/"))
elif type(model) is transformers.LlamaForCausalLM: elif type(model) is transformers.LlamaForCausalLM:
tokenizer = LlamaTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/{shared.model_name}/"), clean_up_tokenization_spaces=True) tokenizer = LlamaTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/{shared.model_name}/"), clean_up_tokenization_spaces=True)
# Leaving this here until the LLaMA tokenizer gets figured out.
# For some people this fixes things, for others it causes an error.
try:
tokenizer.eos_token_id = 2
tokenizer.bos_token_id = 1
tokenizer.pad_token_id = 0
except:
pass
else: else:
tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/{shared.model_name}/")) tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/{shared.model_name}/"))
tokenizer.truncation_side = 'left'
print(f"Loaded the model in {(time.time()-t0):.2f} seconds.") print(f"Loaded the model in {(time.time()-t0):.2f} seconds.")
return model, tokenizer return model, tokenizer
def clear_torch_cache():
gc.collect()
if not shared.args.cpu:
torch.cuda.empty_cache()
def unload_model():
shared.model = shared.tokenizer = None
clear_torch_cache()
def reload_model():
unload_model()
shared.model, shared.tokenizer = load_model(shared.model_name)
def load_soft_prompt(name): def load_soft_prompt(name):
if name == 'None': if name == 'None':
shared.soft_prompt = False shared.soft_prompt = False

View File

@@ -34,13 +34,7 @@ settings = {
'context': 'This is a conversation with your Assistant. The Assistant is very helpful and is eager to chat with you and answer your questions.', 'context': 'This is a conversation with your Assistant. The Assistant is very helpful and is eager to chat with you and answer your questions.',
'greeting': 'Hello there!', 'greeting': 'Hello there!',
'end_of_turn': '', 'end_of_turn': '',
'custom_stopping_strings': '',
'stop_at_newline': False, 'stop_at_newline': False,
'add_bos_token': True,
'ban_eos_token': False,
'truncation_length': 2048,
'truncation_length_min': 0,
'truncation_length_max': 4096,
'chat_prompt_size': 2048, 'chat_prompt_size': 2048,
'chat_prompt_size_min': 0, 'chat_prompt_size_min': 0,
'chat_prompt_size_max': 2048, 'chat_prompt_size_max': 2048,
@@ -50,7 +44,7 @@ settings = {
'default_extensions': [], 'default_extensions': [],
'chat_default_extensions': ["gallery"], 'chat_default_extensions': ["gallery"],
'presets': { 'presets': {
'default': 'Default', 'default': 'NovelAI-Sphinx Moth',
'.*(alpaca|llama)': "LLaMA-Precise", '.*(alpaca|llama)': "LLaMA-Precise",
'.*pygmalion': 'NovelAI-Storywriter', '.*pygmalion': 'NovelAI-Storywriter',
'.*RWKV': 'Naive', '.*RWKV': 'Naive',
@@ -95,7 +89,7 @@ parser.add_argument('--extensions', type=str, nargs="+", help='The list of exten
parser.add_argument('--verbose', action='store_true', help='Print the prompts to the terminal.') parser.add_argument('--verbose', action='store_true', help='Print the prompts to the terminal.')
# Accelerate/transformers # Accelerate/transformers
parser.add_argument('--cpu', action='store_true', help='Use the CPU to generate text. Warning: Training on CPU is extremely slow.') parser.add_argument('--cpu', action='store_true', help='Use the CPU to generate text.')
parser.add_argument('--auto-devices', action='store_true', help='Automatically split the model across the available GPU(s) and CPU.') parser.add_argument('--auto-devices', action='store_true', help='Automatically split the model across the available GPU(s) and CPU.')
parser.add_argument('--gpu-memory', type=str, nargs="+", help='Maxmimum GPU memory in GiB to be allocated per GPU. Example: --gpu-memory 10 for a single GPU, --gpu-memory 10 5 for two GPUs. You can also set values in MiB like --gpu-memory 3500MiB.') parser.add_argument('--gpu-memory', type=str, nargs="+", help='Maxmimum GPU memory in GiB to be allocated per GPU. Example: --gpu-memory 10 for a single GPU, --gpu-memory 10 5 for two GPUs. You can also set values in MiB like --gpu-memory 3500MiB.')
parser.add_argument('--cpu-memory', type=str, help='Maximum CPU memory in GiB to allocate for offloaded weights. Same as above.') parser.add_argument('--cpu-memory', type=str, help='Maximum CPU memory in GiB to allocate for offloaded weights. Same as above.')
@@ -104,8 +98,6 @@ parser.add_argument('--disk-cache-dir', type=str, default="cache", help='Directo
parser.add_argument('--load-in-8bit', action='store_true', help='Load the model with 8-bit precision.') parser.add_argument('--load-in-8bit', action='store_true', help='Load the model with 8-bit precision.')
parser.add_argument('--bf16', action='store_true', help='Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU.') parser.add_argument('--bf16', action='store_true', help='Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU.')
parser.add_argument('--no-cache', action='store_true', help='Set use_cache to False while generating text. This reduces the VRAM usage a bit at a performance cost.') parser.add_argument('--no-cache', action='store_true', help='Set use_cache to False while generating text. This reduces the VRAM usage a bit at a performance cost.')
parser.add_argument('--xformers', action='store_true', help="Use xformer's memory efficient attention. This should increase your tokens/s.")
parser.add_argument('--sdp-attention', action='store_true', help="Use torch 2.0's sdp attention.")
# llama.cpp # llama.cpp
parser.add_argument('--threads', type=int, default=0, help='Number of threads to use in llama.cpp.') parser.add_argument('--threads', type=int, default=0, help='Number of threads to use in llama.cpp.')

View File

@@ -1,4 +1,4 @@
import random import gc
import re import re
import time import time
import traceback import traceback
@@ -12,48 +12,38 @@ from modules.callbacks import (Iteratorize, Stream,
_SentinelTokenStoppingCriteria) _SentinelTokenStoppingCriteria)
from modules.extensions import apply_extensions from modules.extensions import apply_extensions
from modules.html_generator import generate_4chan_html, generate_basic_html from modules.html_generator import generate_4chan_html, generate_basic_html
from modules.models import clear_torch_cache, local_rank from modules.models import local_rank
def get_max_prompt_length(state): def get_max_prompt_length(tokens):
max_length = state['truncation_length'] - state['max_new_tokens'] max_length = 2048 - tokens
if shared.soft_prompt: if shared.soft_prompt:
max_length -= shared.soft_prompt_tensor.shape[1] max_length -= shared.soft_prompt_tensor.shape[1]
return max_length return max_length
def encode(prompt, add_special_tokens=True, add_bos_token=True, truncation_length=None): def encode(prompt, tokens_to_generate=0, add_special_tokens=True):
if any((shared.is_RWKV, shared.is_llamacpp)): if any((shared.is_RWKV, shared.is_llamacpp)):
input_ids = shared.tokenizer.encode(str(prompt)) input_ids = shared.tokenizer.encode(str(prompt))
input_ids = np.array(input_ids).reshape(1, len(input_ids)) input_ids = np.array(input_ids).reshape(1, len(input_ids))
return input_ids return input_ids
else: else:
input_ids = shared.tokenizer.encode(str(prompt), return_tensors='pt', add_special_tokens=add_special_tokens) input_ids = shared.tokenizer.encode(str(prompt), return_tensors='pt', truncation=True, max_length=get_max_prompt_length(tokens_to_generate), add_special_tokens=add_special_tokens)
# This is a hack for making replies more creative.
if not add_bos_token and input_ids[0][0] == shared.tokenizer.bos_token_id:
input_ids = input_ids[:, 1:]
# Llama adds this extra token when the first character is '\n', and this
# compromises the stopping criteria, so we just remove it
if type(shared.tokenizer) is transformers.LlamaTokenizer and input_ids[0][0] == 29871: if type(shared.tokenizer) is transformers.LlamaTokenizer and input_ids[0][0] == 29871:
input_ids = input_ids[:, 1:] input_ids = input_ids[:, 1:]
# Handling truncation if shared.args.cpu:
if truncation_length is not None: return input_ids
input_ids = input_ids[:, -truncation_length:] elif shared.args.flexgen:
return input_ids.numpy()
if any((shared.is_RWKV, shared.is_llamacpp, shared.args.cpu)): elif shared.args.deepspeed:
return input_ids return input_ids.to(device=local_rank)
elif shared.args.flexgen: elif torch.has_mps:
return input_ids.numpy() device = torch.device('mps')
elif shared.args.deepspeed: return input_ids.to(device)
return input_ids.to(device=local_rank) else:
elif torch.has_mps: return input_ids.cuda()
device = torch.device('mps')
return input_ids.to(device)
else:
return input_ids.cuda()
def decode(output_ids): def decode(output_ids):
@@ -73,8 +63,9 @@ def generate_softprompt_input_tensors(input_ids):
# filler_input_ids += shared.model.config.bos_token_id # setting dummy input_ids to bos tokens # filler_input_ids += shared.model.config.bos_token_id # setting dummy input_ids to bos tokens
return inputs_embeds, filler_input_ids return inputs_embeds, filler_input_ids
# Removes empty replies from gpt4chan outputs # Removes empty replies from gpt4chan outputs
def fix_gpt4chan(s): def fix_gpt4chan(s):
for i in range(10): for i in range(10):
s = re.sub("--- [0-9]*\n>>[0-9]*\n---", "---", s) s = re.sub("--- [0-9]*\n>>[0-9]*\n---", "---", s)
@@ -82,8 +73,9 @@ def fix_gpt4chan(s):
s = re.sub("--- [0-9]*\n\n\n---", "---", s) s = re.sub("--- [0-9]*\n\n\n---", "---", s)
return s return s
# Fix the LaTeX equations in galactica # Fix the LaTeX equations in galactica
def fix_galactica(s): def fix_galactica(s):
s = s.replace(r'\[', r'$') s = s.replace(r'\[', r'$')
s = s.replace(r'\]', r'$') s = s.replace(r'\]', r'$')
@@ -109,47 +101,48 @@ def formatted_outputs(reply, model_name):
return reply return reply
def clear_torch_cache():
gc.collect()
if not shared.args.cpu:
torch.cuda.empty_cache()
def set_manual_seed(seed): def set_manual_seed(seed):
seed = int(seed) if seed != -1:
if seed == -1: torch.manual_seed(seed)
seed = random.randint(1, 2**31) if torch.cuda.is_available():
torch.manual_seed(seed) torch.cuda.manual_seed_all(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)
return seed
def stop_everything_event(): def stop_everything_event():
shared.stop_everything = True shared.stop_everything = True
def generate_reply(question, state, eos_token=None, stopping_strings=[]): def generate_reply(question, generate_state, eos_token=None, stopping_strings=[]):
clear_torch_cache() clear_torch_cache()
seed = set_manual_seed(state['seed']) set_manual_seed(generate_state['seed'])
shared.stop_everything = False shared.stop_everything = False
generate_params = {} generate_params = {}
t0 = time.time() t0 = time.time()
original_question = question original_question = question
if not shared.is_chat(): if not shared.is_chat():
question = apply_extensions(question, 'input') question = apply_extensions(question, "input")
if shared.args.verbose:
print(f"\n\n{question}\n--------------------\n")
# These models are not part of Hugging Face, so we handle them # These models are not part of Hugging Face, so we handle them
# separately and terminate the function call earlier # separately and terminate the function call earlier
if any((shared.is_RWKV, shared.is_llamacpp)): if any((shared.is_RWKV, shared.is_llamacpp)):
if shared.args.verbose:
print(f'\n\n{question}\n--------------------\n')
for k in ['temperature', 'top_p', 'top_k', 'repetition_penalty']: for k in ['temperature', 'top_p', 'top_k', 'repetition_penalty']:
generate_params[k] = state[k] generate_params[k] = generate_state[k]
generate_params['token_count'] = state['max_new_tokens'] generate_params["token_count"] = generate_state["max_new_tokens"]
try: try:
if shared.args.no_stream: if shared.args.no_stream:
reply = shared.model.generate(context=question, **generate_params) reply = shared.model.generate(context=question, **generate_params)
output = original_question + reply output = original_question + reply
if not shared.is_chat(): if not shared.is_chat():
reply = original_question + apply_extensions(reply, 'output') reply = original_question + apply_extensions(reply, "output")
yield formatted_outputs(reply, shared.model_name) yield formatted_outputs(reply, shared.model_name)
else: else:
if not shared.is_chat(): if not shared.is_chat():
@@ -160,7 +153,7 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]):
for reply in shared.model.generate_with_streaming(context=question, **generate_params): for reply in shared.model.generate_with_streaming(context=question, **generate_params):
output = original_question + reply output = original_question + reply
if not shared.is_chat(): if not shared.is_chat():
reply = original_question + apply_extensions(reply, 'output') reply = original_question + apply_extensions(reply, "output")
yield formatted_outputs(reply, shared.model_name) yield formatted_outputs(reply, shared.model_name)
except Exception: except Exception:
@@ -169,53 +162,47 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]):
t1 = time.time() t1 = time.time()
original_tokens = len(encode(original_question)[0]) original_tokens = len(encode(original_question)[0])
new_tokens = len(encode(output)[0]) - original_tokens new_tokens = len(encode(output)[0]) - original_tokens
print(f'Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens}, seed {seed})') print(f"Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens})")
return return
input_ids = encode(question, add_bos_token=state['add_bos_token'], truncation_length=get_max_prompt_length(state)) input_ids = encode(question, generate_state['max_new_tokens'])
original_input_ids = input_ids original_input_ids = input_ids
output = input_ids[0] output = input_ids[0]
if shared.args.verbose:
print(f'\n\n{decode(input_ids[0])}\n--------------------\n')
cuda = not any((shared.args.cpu, shared.args.deepspeed, shared.args.flexgen)) cuda = not any((shared.args.cpu, shared.args.deepspeed, shared.args.flexgen))
eos_token_ids = [shared.tokenizer.eos_token_id] if shared.tokenizer.eos_token_id is not None else [] eos_token_ids = [shared.tokenizer.eos_token_id] if shared.tokenizer.eos_token_id is not None else []
if eos_token is not None: if eos_token is not None:
eos_token_ids.append(int(encode(eos_token)[0][-1])) eos_token_ids.append(int(encode(eos_token)[0][-1]))
# Handling the stopping strings
stopping_criteria_list = transformers.StoppingCriteriaList() stopping_criteria_list = transformers.StoppingCriteriaList()
for st in [stopping_strings, state['custom_stopping_strings']]: if type(stopping_strings) is list and len(stopping_strings) > 0:
if type(st) is list and len(st) > 0: t = [encode(string, 0, add_special_tokens=False) for string in stopping_strings]
sentinel_token_ids = [encode(string, add_special_tokens=False) for string in st] stopping_criteria_list.append(_SentinelTokenStoppingCriteria(sentinel_token_ids=t, starting_idx=len(input_ids[0])))
stopping_criteria_list.append(_SentinelTokenStoppingCriteria(sentinel_token_ids=sentinel_token_ids, starting_idx=len(input_ids[0])))
break
generate_params["max_new_tokens"] = generate_state['max_new_tokens']
if not shared.args.flexgen: if not shared.args.flexgen:
for k in ['max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']: for k in ["do_sample", "temperature", "top_p", "typical_p", "repetition_penalty", "encoder_repetition_penalty", "top_k", "min_length", "no_repeat_ngram_size", "num_beams", "penalty_alpha", "length_penalty", "early_stopping"]:
generate_params[k] = state[k] generate_params[k] = generate_state[k]
generate_params['eos_token_id'] = eos_token_ids generate_params["eos_token_id"] = eos_token_ids
generate_params['stopping_criteria'] = stopping_criteria_list generate_params["stopping_criteria"] = stopping_criteria_list
if state['ban_eos_token']: if shared.args.no_stream:
generate_params['suppress_tokens'] = [shared.tokenizer.eos_token_id] generate_params["min_length"] = 0
else: else:
for k in ['max_new_tokens', 'do_sample', 'temperature']: for k in ["do_sample", "temperature"]:
generate_params[k] = state[k] generate_params[k] = generate_state[k]
generate_params['stop'] = state['eos_token_ids'][-1] generate_params["stop"] = generate_state["eos_token_ids"][-1]
if not shared.args.no_stream: if not shared.args.no_stream:
generate_params['max_new_tokens'] = 8 generate_params["max_new_tokens"] = 8
if shared.args.no_cache: if shared.args.no_cache:
generate_params.update({'use_cache': False}) generate_params.update({"use_cache": False})
if shared.args.deepspeed: if shared.args.deepspeed:
generate_params.update({'synced_gpus': True}) generate_params.update({"synced_gpus": True})
if shared.soft_prompt: if shared.soft_prompt:
inputs_embeds, filler_input_ids = generate_softprompt_input_tensors(input_ids) inputs_embeds, filler_input_ids = generate_softprompt_input_tensors(input_ids)
generate_params.update({'inputs_embeds': inputs_embeds}) generate_params.update({"inputs_embeds": inputs_embeds})
generate_params.update({'inputs': filler_input_ids}) generate_params.update({"inputs": filler_input_ids})
else: else:
generate_params.update({'inputs': input_ids}) generate_params.update({"inputs": input_ids})
try: try:
# Generate the entire reply at once. # Generate the entire reply at once.
@@ -230,7 +217,7 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]):
new_tokens = len(output) - len(input_ids[0]) new_tokens = len(output) - len(input_ids[0])
reply = decode(output[-new_tokens:]) reply = decode(output[-new_tokens:])
if not shared.is_chat(): if not shared.is_chat():
reply = original_question + apply_extensions(reply, 'output') reply = original_question + apply_extensions(reply, "output")
yield formatted_outputs(reply, shared.model_name) yield formatted_outputs(reply, shared.model_name)
@@ -257,7 +244,7 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]):
new_tokens = len(output) - len(input_ids[0]) new_tokens = len(output) - len(input_ids[0])
reply = decode(output[-new_tokens:]) reply = decode(output[-new_tokens:])
if not shared.is_chat(): if not shared.is_chat():
reply = original_question + apply_extensions(reply, 'output') reply = original_question + apply_extensions(reply, "output")
if output[-1] in eos_token_ids: if output[-1] in eos_token_ids:
break break
@@ -265,7 +252,7 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]):
# Stream the output naively for FlexGen since it doesn't support 'stopping_criteria' # Stream the output naively for FlexGen since it doesn't support 'stopping_criteria'
else: else:
for i in range(state['max_new_tokens'] // 8 + 1): for i in range(generate_state['max_new_tokens'] // 8 + 1):
clear_torch_cache() clear_torch_cache()
with torch.no_grad(): with torch.no_grad():
output = shared.model.generate(**generate_params)[0] output = shared.model.generate(**generate_params)[0]
@@ -275,7 +262,7 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]):
new_tokens = len(output) - len(original_input_ids[0]) new_tokens = len(output) - len(original_input_ids[0])
reply = decode(output[-new_tokens:]) reply = decode(output[-new_tokens:])
if not shared.is_chat(): if not shared.is_chat():
reply = original_question + apply_extensions(reply, 'output') reply = original_question + apply_extensions(reply, "output")
if np.count_nonzero(np.isin(input_ids[0], eos_token_ids)) < np.count_nonzero(np.isin(output, eos_token_ids)): if np.count_nonzero(np.isin(input_ids[0], eos_token_ids)) < np.count_nonzero(np.isin(output, eos_token_ids)):
break break
@@ -284,10 +271,10 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]):
input_ids = np.reshape(output, (1, output.shape[0])) input_ids = np.reshape(output, (1, output.shape[0]))
if shared.soft_prompt: if shared.soft_prompt:
inputs_embeds, filler_input_ids = generate_softprompt_input_tensors(input_ids) inputs_embeds, filler_input_ids = generate_softprompt_input_tensors(input_ids)
generate_params.update({'inputs_embeds': inputs_embeds}) generate_params.update({"inputs_embeds": inputs_embeds})
generate_params.update({'inputs': filler_input_ids}) generate_params.update({"inputs": filler_input_ids})
else: else:
generate_params.update({'inputs': input_ids}) generate_params.update({"inputs": input_ids})
yield formatted_outputs(reply, shared.model_name) yield formatted_outputs(reply, shared.model_name)
@@ -297,5 +284,5 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]):
t1 = time.time() t1 = time.time()
original_tokens = len(original_input_ids[0]) original_tokens = len(original_input_ids[0])
new_tokens = len(output) - original_tokens new_tokens = len(output) - original_tokens
print(f'Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens}, seed {seed})') print(f"Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens})")
return return

View File

@@ -152,7 +152,7 @@ def do_train(lora_name: str, micro_batch_size: int, batch_size: int, epochs: int
# == Prep the dataset, format, etc == # == Prep the dataset, format, etc ==
if raw_text_file not in ['None', '']: if raw_text_file not in ['None', '']:
print("Loading raw text file dataset...") print("Loading raw text file dataset...")
with open(clean_path('training/datasets', f'{raw_text_file}.txt'), 'r', encoding='utf-8') as file: with open(clean_path('training/datasets', f'{raw_text_file}.txt'), 'r') as file:
raw_text = file.read() raw_text = file.read()
tokens = shared.tokenizer.encode(raw_text) tokens = shared.tokenizer.encode(raw_text)
del raw_text # Note: could be a gig for a large dataset, so delete redundant data as we go to be safe on RAM del raw_text # Note: could be a gig for a large dataset, so delete redundant data as we go to be safe on RAM
@@ -238,7 +238,7 @@ def do_train(lora_name: str, micro_batch_size: int, batch_size: int, epochs: int
warmup_steps=100, warmup_steps=100,
num_train_epochs=epochs, num_train_epochs=epochs,
learning_rate=actual_lr, learning_rate=actual_lr,
fp16=False if shared.args.cpu else True, fp16=True,
logging_steps=20, logging_steps=20,
evaluation_strategy="steps" if eval_data is not None else "no", evaluation_strategy="steps" if eval_data is not None else "no",
save_strategy="steps", save_strategy="steps",
@@ -248,8 +248,7 @@ def do_train(lora_name: str, micro_batch_size: int, batch_size: int, epochs: int
save_total_limit=3, save_total_limit=3,
load_best_model_at_end=True if eval_data is not None else False, load_best_model_at_end=True if eval_data is not None else False,
# TODO: Enable multi-device support # TODO: Enable multi-device support
ddp_find_unused_parameters=None, ddp_find_unused_parameters=None
no_cuda=shared.args.cpu
), ),
data_collator=transformers.DataCollatorForLanguageModeling(shared.tokenizer, mlm=False), data_collator=transformers.DataCollatorForLanguageModeling(shared.tokenizer, mlm=False),
callbacks=list([Callbacks()]) callbacks=list([Callbacks()])

View File

@@ -1,6 +1,7 @@
do_sample=True do_sample=True
top_p=0.95 top_p=0.5
top_k=50 top_k=40
temperature=1 temperature=0.7
repetition_penalty=1.2 repetition_penalty=1.2
typical_p=1.0 typical_p=1.0
early_stopping=False

View File

@@ -1,10 +1,10 @@
accelerate==0.18.0 accelerate==0.18.0
bitsandbytes==0.37.2
datasets datasets
flexgen==0.1.7 flexgen==0.1.7
gradio==3.24.1 gradio==3.24.1
markdown markdown
numpy numpy
Pillow>=9.5.0
peft==0.2.0 peft==0.2.0
requests requests
rwkv==0.7.3 rwkv==0.7.3
@@ -13,6 +13,3 @@ sentencepiece
pyyaml pyyaml
tqdm tqdm
git+https://github.com/huggingface/transformers git+https://github.com/huggingface/transformers
bitsandbytes==0.37.2; platform_system != "Windows"
llama-cpp-python==0.1.30; platform_system != "Windows"
https://github.com/abetlen/llama-cpp-python/releases/download/v0.1.30/llama_cpp_python-0.1.30-cp310-cp310-win_amd64.whl; platform_system == "Windows"

320
server.py
View File

@@ -2,14 +2,11 @@ import os
os.environ['GRADIO_ANALYTICS_ENABLED'] = 'False' os.environ['GRADIO_ANALYTICS_ENABLED'] = 'False'
import importlib
import io import io
import json import json
import os
import re import re
import sys import sys
import time import time
import traceback
import zipfile import zipfile
from datetime import datetime from datetime import datetime
from pathlib import Path from pathlib import Path
@@ -18,12 +15,12 @@ import gradio as gr
from PIL import Image from PIL import Image
import modules.extensions as extensions_module import modules.extensions as extensions_module
from modules import api, chat, shared, training, ui from modules import chat, shared, training, ui, api
from modules.html_generator import chat_html_wrapper from modules.html_generator import chat_html_wrapper
from modules.LoRA import add_lora_to_model from modules.LoRA import add_lora_to_model
from modules.models import load_model, load_soft_prompt, unload_model from modules.models import load_model, load_soft_prompt
from modules.text_generation import generate_reply, stop_everything_event from modules.text_generation import (clear_torch_cache, generate_reply,
stop_everything_event)
# Loading custom settings # Loading custom settings
settings_file = None settings_file = None
@@ -82,6 +79,11 @@ def get_available_loras():
return ['None'] + sorted([item.name for item in list(Path(shared.args.lora_dir).glob('*')) if not item.name.endswith(('.txt', '-np', '.pt', '.json'))], key=str.lower) return ['None'] + sorted([item.name for item in list(Path(shared.args.lora_dir).glob('*')) if not item.name.endswith(('.txt', '-np', '.pt', '.json'))], key=str.lower)
def unload_model():
shared.model = shared.tokenizer = None
clear_torch_cache()
def load_model_wrapper(selected_model): def load_model_wrapper(selected_model):
if selected_model != shared.model_name: if selected_model != shared.model_name:
shared.model_name = selected_model shared.model_name = selected_model
@@ -176,34 +178,6 @@ def create_prompt_menus():
shared.gradio['save_prompt'].click(save_prompt, [shared.gradio['textbox']], [shared.gradio['status']], show_progress=False) shared.gradio['save_prompt'].click(save_prompt, [shared.gradio['textbox']], [shared.gradio['status']], show_progress=False)
def download_model_wrapper(repo_id):
try:
downloader = importlib.import_module("download-model")
model = repo_id
branch = "main"
check = False
yield ("Cleaning up the model/branch names")
model, branch = downloader.sanitize_model_and_branch_names(model, branch)
yield ("Getting the download links from Hugging Face")
links, sha256, is_lora = downloader.get_download_links_from_huggingface(model, branch, text_only=False)
yield ("Getting the output folder")
output_folder = downloader.get_output_folder(model, branch, is_lora)
if check:
yield ("Checking previously downloaded files")
downloader.check_model_files(model, branch, links, sha256, output_folder)
else:
yield (f"Downloading files to {output_folder}")
downloader.download_model_files(model, branch, links, sha256, output_folder, threads=1)
yield ("Done!")
except:
yield traceback.format_exc()
def create_model_menus(): def create_model_menus():
with gr.Row(): with gr.Row():
with gr.Column(): with gr.Column():
@@ -214,26 +188,16 @@ def create_model_menus():
with gr.Row(): with gr.Row():
shared.gradio['lora_menu'] = gr.Dropdown(choices=available_loras, value=shared.lora_name, label='LoRA') shared.gradio['lora_menu'] = gr.Dropdown(choices=available_loras, value=shared.lora_name, label='LoRA')
ui.create_refresh_button(shared.gradio['lora_menu'], lambda: None, lambda: {'choices': get_available_loras()}, 'refresh-button') ui.create_refresh_button(shared.gradio['lora_menu'], lambda: None, lambda: {'choices': get_available_loras()}, 'refresh-button')
with gr.Row():
with gr.Column():
with gr.Row():
with gr.Column():
shared.gradio['custom_model_menu'] = gr.Textbox(label="Download custom model or LoRA",
info="Enter Hugging Face username/model path, e.g: facebook/galactica-125m")
with gr.Column():
shared.gradio['download_button'] = gr.Button("Download")
shared.gradio['download_status'] = gr.Markdown()
with gr.Column():
pass
shared.gradio['model_menu'].change(load_model_wrapper, shared.gradio['model_menu'], shared.gradio['model_menu'], show_progress=True) shared.gradio['model_menu'].change(load_model_wrapper, shared.gradio['model_menu'], shared.gradio['model_menu'], show_progress=True)
shared.gradio['lora_menu'].change(load_lora_wrapper, shared.gradio['lora_menu'], shared.gradio['lora_menu'], show_progress=True) shared.gradio['lora_menu'].change(load_lora_wrapper, shared.gradio['lora_menu'], shared.gradio['lora_menu'], show_progress=True)
shared.gradio['download_button'].click(download_model_wrapper, shared.gradio['custom_model_menu'], shared.gradio['download_status'], show_progress=False)
def create_settings_menus(default_preset): def create_settings_menus(default_preset):
generate_params = load_preset_values(default_preset if not shared.args.flexgen else 'Naive', {}, return_dict=True) generate_params = load_preset_values(default_preset if not shared.args.flexgen else 'Naive', {}, return_dict=True)
for k in ['max_new_tokens', 'seed', 'stop_at_newline', 'chat_prompt_size', 'chat_generation_attempts']:
generate_params[k] = shared.settings[k]
shared.gradio['generate_state'] = gr.State(generate_params)
with gr.Row(): with gr.Row():
with gr.Column(): with gr.Column():
@@ -246,24 +210,24 @@ def create_settings_menus(default_preset):
with gr.Row(): with gr.Row():
with gr.Column(): with gr.Column():
with gr.Box(): with gr.Box():
gr.Markdown('Custom generation parameters ([click here to view technical documentation](https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.GenerationConfig))') gr.Markdown('Custom generation parameters ([reference](https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.GenerationConfig))')
with gr.Row(): with gr.Row():
with gr.Column(): with gr.Column():
shared.gradio['temperature'] = gr.Slider(0.01, 1.99, value=generate_params['temperature'], step=0.01, label='temperature', info='Primary factor to control randomness of outputs. 0 = deterministic (only the most likely token is used). Higher value = more randomness.') shared.gradio['temperature'] = gr.Slider(0.01, 1.99, value=generate_params['temperature'], step=0.01, label='temperature')
shared.gradio['top_p'] = gr.Slider(0.0, 1.0, value=generate_params['top_p'], step=0.01, label='top_p', info='If not set to 1, select tokens with probabilities adding up to less than this number. Higher value = higher range of possible random results.') shared.gradio['top_p'] = gr.Slider(0.0, 1.0, value=generate_params['top_p'], step=0.01, label='top_p')
shared.gradio['top_k'] = gr.Slider(0, 200, value=generate_params['top_k'], step=1, label='top_k', info='Similar to top_p, but select instead only the top_k most likely tokens. Higher value = higher range of possible random results.') shared.gradio['top_k'] = gr.Slider(0, 200, value=generate_params['top_k'], step=1, label='top_k')
shared.gradio['typical_p'] = gr.Slider(0.0, 1.0, value=generate_params['typical_p'], step=0.01, label='typical_p', info='If not set to 1, select only tokens that are at least this much more likely to appear than random tokens, given the prior text.') shared.gradio['typical_p'] = gr.Slider(0.0, 1.0, value=generate_params['typical_p'], step=0.01, label='typical_p')
with gr.Column(): with gr.Column():
shared.gradio['repetition_penalty'] = gr.Slider(1.0, 1.5, value=generate_params['repetition_penalty'], step=0.01, label='repetition_penalty', info='Exponential penalty factor for repeating prior tokens. 1 means no penalty, higher value = less repetition, lower value = more repetition.') shared.gradio['repetition_penalty'] = gr.Slider(1.0, 1.5, value=generate_params['repetition_penalty'], step=0.01, label='repetition_penalty')
shared.gradio['encoder_repetition_penalty'] = gr.Slider(0.8, 1.5, value=generate_params['encoder_repetition_penalty'], step=0.01, label='encoder_repetition_penalty', info='Also known as the "Hallucinations filter". Used to penalize tokens that are *not* in the prior text. Higher value = more likely to stay in context, lower value = more likely to diverge.') shared.gradio['encoder_repetition_penalty'] = gr.Slider(0.8, 1.5, value=generate_params['encoder_repetition_penalty'], step=0.01, label='encoder_repetition_penalty')
shared.gradio['no_repeat_ngram_size'] = gr.Slider(0, 20, step=1, value=generate_params['no_repeat_ngram_size'], label='no_repeat_ngram_size', info='If not set to 0, specifies the length of token sets that are completely blocked from repeating at all. Higher values = blocks larger phrases, lower values = blocks words or letters from repeating. Only 0 or high values are a good idea in most cases.') shared.gradio['no_repeat_ngram_size'] = gr.Slider(0, 20, step=1, value=generate_params['no_repeat_ngram_size'], label='no_repeat_ngram_size')
shared.gradio['min_length'] = gr.Slider(0, 2000, step=1, value=generate_params['min_length'], label='min_length', info='Minimum generation length in tokens.') shared.gradio['min_length'] = gr.Slider(0, 2000, step=1, value=generate_params['min_length'] if shared.args.no_stream else 0, label='min_length', interactive=shared.args.no_stream)
shared.gradio['do_sample'] = gr.Checkbox(value=generate_params['do_sample'], label='do_sample') shared.gradio['do_sample'] = gr.Checkbox(value=generate_params['do_sample'], label='do_sample')
with gr.Column(): with gr.Column():
with gr.Box(): with gr.Box():
gr.Markdown('Contrastive search') gr.Markdown('Contrastive search')
shared.gradio['penalty_alpha'] = gr.Slider(0, 5, value=generate_params['penalty_alpha'], label='penalty_alpha') shared.gradio['penalty_alpha'] = gr.Slider(0, 5, value=generate_params['penalty_alpha'], label='penalty_alpha')
with gr.Box():
gr.Markdown('Beam search (uses a lot of VRAM)') gr.Markdown('Beam search (uses a lot of VRAM)')
with gr.Row(): with gr.Row():
with gr.Column(): with gr.Column():
@@ -272,13 +236,6 @@ def create_settings_menus(default_preset):
shared.gradio['length_penalty'] = gr.Slider(-5, 5, value=generate_params['length_penalty'], label='length_penalty') shared.gradio['length_penalty'] = gr.Slider(-5, 5, value=generate_params['length_penalty'], label='length_penalty')
shared.gradio['early_stopping'] = gr.Checkbox(value=generate_params['early_stopping'], label='early_stopping') shared.gradio['early_stopping'] = gr.Checkbox(value=generate_params['early_stopping'], label='early_stopping')
with gr.Group():
with gr.Row():
shared.gradio['add_bos_token'] = gr.Checkbox(value=shared.settings['add_bos_token'], label='Add the bos_token to the beginning of prompts', info='Disabling this can make the replies more creative.')
shared.gradio['ban_eos_token'] = gr.Checkbox(value=shared.settings['ban_eos_token'], label='Ban the eos_token', info='This forces the model to never end the generation prematurely.')
shared.gradio['truncation_length'] = gr.Slider(value=shared.settings['truncation_length'], minimum=shared.settings['truncation_length_min'], maximum=shared.settings['truncation_length_max'], step=1, label='Truncate the prompt up to this length', info='The leftmost tokens are removed if the prompt exceeds this length. Most models require this to be at most 2048.')
shared.gradio['custom_stopping_strings'] = gr.Textbox(lines=1, value=shared.settings["custom_stopping_strings"] or None, label='Custom stopping strings', info='In addition to the defaults. Written between "" and separated by commas. For instance: "\\nYour Assistant:", "\\nThe assistant:"')
with gr.Accordion('Soft prompt', open=False): with gr.Accordion('Soft prompt', open=False):
with gr.Row(): with gr.Row():
shared.gradio['softprompts_menu'] = gr.Dropdown(choices=available_softprompts, value='None', label='Soft prompt') shared.gradio['softprompts_menu'] = gr.Dropdown(choices=available_softprompts, value='None', label='Soft prompt')
@@ -288,7 +245,7 @@ def create_settings_menus(default_preset):
with gr.Row(): with gr.Row():
shared.gradio['upload_softprompt'] = gr.File(type='binary', file_types=['.zip']) shared.gradio['upload_softprompt'] = gr.File(type='binary', file_types=['.zip'])
shared.gradio['preset_menu'].change(load_preset_values, [shared.gradio[k] for k in ['preset_menu', 'interface_state']], [shared.gradio[k] for k in ['interface_state', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']]) shared.gradio['preset_menu'].change(load_preset_values, [shared.gradio[k] for k in ['preset_menu', 'generate_state']], [shared.gradio[k] for k in ['generate_state', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']])
shared.gradio['softprompts_menu'].change(load_soft_prompt, shared.gradio['softprompts_menu'], shared.gradio['softprompts_menu'], show_progress=True) shared.gradio['softprompts_menu'].change(load_soft_prompt, shared.gradio['softprompts_menu'], shared.gradio['softprompts_menu'], show_progress=True)
shared.gradio['upload_softprompt'].upload(upload_soft_prompt, shared.gradio['upload_softprompt'], shared.gradio['softprompts_menu']) shared.gradio['upload_softprompt'].upload(upload_soft_prompt, shared.gradio['upload_softprompt'], shared.gradio['softprompts_menu'])
@@ -361,21 +318,6 @@ else:
title = 'Text generation web UI' title = 'Text generation web UI'
def list_interface_input_elements(chat=False):
elements = ['max_new_tokens', 'seed', 'temperature', 'top_p', 'top_k', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'no_repeat_ngram_size', 'min_length', 'do_sample', 'penalty_alpha', 'num_beams', 'length_penalty', 'early_stopping', 'add_bos_token', 'ban_eos_token', 'truncation_length', 'custom_stopping_strings']
if chat:
elements += ['name1', 'name2', 'greeting', 'context', 'end_of_turn', 'chat_prompt_size', 'chat_generation_attempts', 'stop_at_newline', 'mode']
return elements
def gather_interface_values(*args):
output = {}
for i, element in enumerate(shared.input_elements):
output[element] = args[i]
output['custom_stopping_strings'] = eval(f"[{output['custom_stopping_strings']}]")
return output
def create_interface(): def create_interface():
gen_events = [] gen_events = []
if shared.args.extensions is not None and len(shared.args.extensions) > 0: if shared.args.extensions is not None and len(shared.args.extensions) > 0:
@@ -383,34 +325,27 @@ def create_interface():
with gr.Blocks(css=ui.css if not shared.is_chat() else ui.css + ui.chat_css, analytics_enabled=False, title=title) as shared.gradio['interface']: with gr.Blocks(css=ui.css if not shared.is_chat() else ui.css + ui.chat_css, analytics_enabled=False, title=title) as shared.gradio['interface']:
if shared.is_chat(): if shared.is_chat():
shared.input_elements = list_interface_input_elements(chat=True)
shared.gradio['interface_state'] = gr.State({k: None for k in shared.input_elements})
shared.gradio['Chat input'] = gr.State() shared.gradio['Chat input'] = gr.State()
with gr.Tab("Text generation", elem_id="main"): with gr.Tab("Text generation", elem_id="main"):
shared.gradio['display'] = gr.HTML(value=chat_html_wrapper(shared.history['visible'], shared.settings['name1'], shared.settings['name2'], 'cai-chat')) shared.gradio['display'] = gr.HTML(value=chat_html_wrapper(shared.history['visible'], shared.settings['name1'], shared.settings['name2'], 'cai-chat'))
shared.gradio['textbox'] = gr.Textbox(label='Input') shared.gradio['textbox'] = gr.Textbox(label='Input')
with gr.Row(): with gr.Row():
shared.gradio['Generate'] = gr.Button('Generate', elem_id='Generate') shared.gradio['Generate'] = gr.Button('Generate')
shared.gradio['Stop'] = gr.Button('Stop', elem_id="stop") shared.gradio['Stop'] = gr.Button('Stop', elem_id="stop")
with gr.Row(): with gr.Row():
shared.gradio['Regenerate'] = gr.Button('Regenerate')
shared.gradio['Continue'] = gr.Button('Continue')
shared.gradio['Impersonate'] = gr.Button('Impersonate') shared.gradio['Impersonate'] = gr.Button('Impersonate')
shared.gradio['Regenerate'] = gr.Button('Regenerate')
with gr.Row(): with gr.Row():
shared.gradio['Send dummy message'] = gr.Button('Send dummy message')
shared.gradio['Send dummy reply'] = gr.Button('Send dummy reply')
shared.gradio['Replace last reply'] = gr.Button('Replace last reply')
shared.gradio['Copy last reply'] = gr.Button('Copy last reply') shared.gradio['Copy last reply'] = gr.Button('Copy last reply')
with gr.Row(): shared.gradio['Replace last reply'] = gr.Button('Replace last reply')
shared.gradio['Remove last'] = gr.Button('Remove last')
shared.gradio['Clear history'] = gr.Button('Clear history') shared.gradio['Clear history'] = gr.Button('Clear history')
shared.gradio['Clear history-confirm'] = gr.Button('Confirm', variant="stop", visible=False) shared.gradio['Clear history-confirm'] = gr.Button('Confirm', variant="stop", visible=False)
shared.gradio['Clear history-cancel'] = gr.Button('Cancel', visible=False) shared.gradio['Clear history-cancel'] = gr.Button('Cancel', visible=False)
shared.gradio['Remove last'] = gr.Button('Remove last')
shared.gradio["mode"] = gr.Radio(choices=["cai-chat", "chat", "instruct"], value="cai-chat", label="Mode") shared.gradio["Chat mode"] = gr.Radio(choices=["cai-chat", "chat", "instruct"], value="cai-chat", label="Mode")
shared.gradio["Instruction templates"] = gr.Dropdown(choices=get_available_instruction_templates(), label="Instruction template", value="None", visible=False, info="Change this according to the model/LoRA that you are using.") shared.gradio["Instruction templates"] = gr.Dropdown(choices=get_available_instruction_templates(), label="Instruction template", value="None", visible=False)
with gr.Tab("Character", elem_id="chat-settings"): with gr.Tab("Character", elem_id="chat-settings"):
with gr.Row(): with gr.Row():
@@ -457,102 +392,66 @@ def create_interface():
with gr.Row(): with gr.Row():
with gr.Column(): with gr.Column():
shared.gradio['max_new_tokens'] = gr.Slider(minimum=shared.settings['max_new_tokens_min'], maximum=shared.settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=shared.settings['max_new_tokens']) shared.gradio['max_new_tokens'] = gr.Slider(minimum=shared.settings['max_new_tokens_min'], maximum=shared.settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=shared.settings['max_new_tokens'])
shared.gradio['chat_prompt_size'] = gr.Slider(minimum=shared.settings['chat_prompt_size_min'], maximum=shared.settings['chat_prompt_size_max'], step=1, label='Maximum prompt size in tokens', value=shared.settings['chat_prompt_size']) shared.gradio['chat_prompt_size_slider'] = gr.Slider(minimum=shared.settings['chat_prompt_size_min'], maximum=shared.settings['chat_prompt_size_max'], step=1, label='Maximum prompt size in tokens', value=shared.settings['chat_prompt_size'])
with gr.Column(): with gr.Column():
shared.gradio['chat_generation_attempts'] = gr.Slider(minimum=shared.settings['chat_generation_attempts_min'], maximum=shared.settings['chat_generation_attempts_max'], value=shared.settings['chat_generation_attempts'], step=1, label='Generation attempts (for longer replies)') shared.gradio['chat_generation_attempts'] = gr.Slider(minimum=shared.settings['chat_generation_attempts_min'], maximum=shared.settings['chat_generation_attempts_max'], value=shared.settings['chat_generation_attempts'], step=1, label='Generation attempts (for longer replies)')
shared.gradio['stop_at_newline'] = gr.Checkbox(value=shared.settings['stop_at_newline'], label='Stop generating at new line character') shared.gradio['stop_at_newline'] = gr.Checkbox(value=shared.settings['stop_at_newline'], label='Stop generating at new line character?')
create_settings_menus(default_preset) create_settings_menus(default_preset)
shared.input_params = [shared.gradio[k] for k in ['Chat input', 'interface_state']] shared.input_params = [shared.gradio[k] for k in ['Chat input', 'generate_state', 'name1', 'name2', 'context', 'Chat mode', 'end_of_turn']]
def set_chat_input(textbox):
return textbox, ""
gen_events.append(shared.gradio['Generate'].click(set_chat_input, shared.gradio['textbox'], [shared.gradio['Chat input'], shared.gradio['textbox']], show_progress=False))
gen_events.append(shared.gradio['Generate'].click(chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream))
gen_events.append(shared.gradio['textbox'].submit(set_chat_input, shared.gradio['textbox'], [shared.gradio['Chat input'], shared.gradio['textbox']], show_progress=False))
gen_events.append(shared.gradio['textbox'].submit(chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream))
gen_events.append(shared.gradio['Regenerate'].click(chat.regenerate_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream))
gen_events.append(shared.gradio['Impersonate'].click(chat.impersonate_wrapper, shared.input_params, shared.gradio['textbox'], show_progress=shared.args.no_stream))
shared.gradio['Stop'].click(stop_everything_event, [], [], queue=False, cancels=gen_events if shared.args.no_stream else None)
shared.gradio['Copy last reply'].click(chat.send_last_reply_to_input, [], shared.gradio['textbox'], show_progress=shared.args.no_stream)
shared.gradio['Replace last reply'].click(chat.replace_last_reply, [shared.gradio[k] for k in ['textbox', 'name1', 'name2', 'Chat mode']], shared.gradio['display'], show_progress=shared.args.no_stream)
# Clear history with confirmation
clear_arr = [shared.gradio[k] for k in ['Clear history-confirm', 'Clear history', 'Clear history-cancel']] clear_arr = [shared.gradio[k] for k in ['Clear history-confirm', 'Clear history', 'Clear history-cancel']]
reload_inputs = [shared.gradio[k] for k in ['name1', 'name2', 'mode']]
gen_events.append(shared.gradio['Generate'].click(
gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
lambda x: (x, ''), shared.gradio['textbox'], [shared.gradio['Chat input'], shared.gradio['textbox']], show_progress=False).then(
chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream).then(
chat.save_history, shared.gradio['mode'], None, show_progress=False)
)
gen_events.append(shared.gradio['textbox'].submit(
gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
lambda x: (x, ''), shared.gradio['textbox'], [shared.gradio['Chat input'], shared.gradio['textbox']], show_progress=False).then(
chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream).then(
chat.save_history, shared.gradio['mode'], None, show_progress=False)
)
gen_events.append(shared.gradio['Regenerate'].click(
gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
chat.regenerate_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream).then(
chat.save_history, shared.gradio['mode'], None, show_progress=False)
)
gen_events.append(shared.gradio['Continue'].click(
gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
chat.continue_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream).then(
chat.save_history, shared.gradio['mode'], None, show_progress=False)
)
gen_events.append(shared.gradio['Impersonate'].click(
gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
chat.impersonate_wrapper, shared.input_params, shared.gradio['textbox'], show_progress=shared.args.no_stream)
)
shared.gradio['Replace last reply'].click(
chat.replace_last_reply, [shared.gradio[k] for k in ['textbox', 'name1', 'name2', 'mode']], shared.gradio['display'], show_progress=shared.args.no_stream).then(
lambda x: '', shared.gradio['textbox'], shared.gradio['textbox'], show_progress=False).then(
chat.save_history, shared.gradio['mode'], None, show_progress=False)
shared.gradio['Send dummy message'].click(
chat.send_dummy_message, [shared.gradio[k] for k in ['textbox', 'name1', 'name2', 'mode']], shared.gradio['display'], show_progress=shared.args.no_stream).then(
lambda x: '', shared.gradio['textbox'], shared.gradio['textbox'], show_progress=False).then(
chat.save_history, shared.gradio['mode'], None, show_progress=False)
shared.gradio['Send dummy reply'].click(
chat.send_dummy_reply, [shared.gradio[k] for k in ['textbox', 'name1', 'name2', 'mode']], shared.gradio['display'], show_progress=shared.args.no_stream).then(
lambda x: '', shared.gradio['textbox'], shared.gradio['textbox'], show_progress=False).then(
chat.save_history, shared.gradio['mode'], None, show_progress=False)
shared.gradio['Clear history-confirm'].click(
lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, clear_arr).then(
chat.clear_chat_log, [shared.gradio[k] for k in ['name1', 'name2', 'greeting', 'mode']], shared.gradio['display']).then(
chat.save_history, shared.gradio['mode'], None, show_progress=False)
shared.gradio['Stop'].click(
stop_everything_event, None, None, queue=False, cancels=gen_events if shared.args.no_stream else None).then(
chat.redraw_html, reload_inputs, shared.gradio['display'])
shared.gradio['mode'].change(
lambda x: gr.update(visible=x == 'instruct'), shared.gradio['mode'], shared.gradio['Instruction templates']).then(
lambda x: gr.update(interactive=x != 'instruct'), shared.gradio['mode'], shared.gradio['character_menu']).then(
chat.redraw_html, reload_inputs, shared.gradio['display'])
shared.gradio['Instruction templates'].change(
lambda character, name1, name2, mode: chat.load_character(character, name1, name2, mode), [shared.gradio[k] for k in ['Instruction templates', 'name1', 'name2', 'mode']], [shared.gradio[k] for k in ['name1', 'name2', 'character_picture', 'greeting', 'context', 'end_of_turn', 'display']]).then(
chat.redraw_html, reload_inputs, shared.gradio['display'])
shared.gradio['upload_chat_history'].upload(
chat.load_history, [shared.gradio[k] for k in ['upload_chat_history', 'name1', 'name2']], None).then(
chat.redraw_html, reload_inputs, shared.gradio['display'])
shared.gradio['Copy last reply'].click(chat.send_last_reply_to_input, None, shared.gradio['textbox'], show_progress=shared.args.no_stream)
shared.gradio['Clear history'].click(lambda: [gr.update(visible=True), gr.update(visible=False), gr.update(visible=True)], None, clear_arr) shared.gradio['Clear history'].click(lambda: [gr.update(visible=True), gr.update(visible=False), gr.update(visible=True)], None, clear_arr)
shared.gradio['Clear history-confirm'].click(lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, clear_arr)
shared.gradio['Clear history-confirm'].click(chat.clear_chat_log, [shared.gradio[k] for k in ['name1', 'name2', 'greeting', 'Chat mode']], shared.gradio['display'])
shared.gradio['Clear history-cancel'].click(lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, clear_arr) shared.gradio['Clear history-cancel'].click(lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, clear_arr)
shared.gradio['Remove last'].click(chat.remove_last_message, [shared.gradio[k] for k in ['name1', 'name2', 'mode']], [shared.gradio['display'], shared.gradio['textbox']], show_progress=False) shared.gradio['Chat mode'].change(lambda x: gr.update(visible=x == 'instruct'), shared.gradio['Chat mode'], shared.gradio['Instruction templates'])
shared.gradio['download_button'].click(lambda x: chat.save_history(x, timestamp=True), shared.gradio['mode'], shared.gradio['download'])
shared.gradio['Remove last'].click(chat.remove_last_message, [shared.gradio[k] for k in ['name1', 'name2', 'Chat mode']], [shared.gradio['display'], shared.gradio['textbox']], show_progress=False)
shared.gradio['download_button'].click(chat.save_history, inputs=[], outputs=[shared.gradio['download']])
shared.gradio['Upload character'].click(chat.upload_character, [shared.gradio['upload_json'], shared.gradio['upload_img_bot']], [shared.gradio['character_menu']]) shared.gradio['Upload character'].click(chat.upload_character, [shared.gradio['upload_json'], shared.gradio['upload_img_bot']], [shared.gradio['character_menu']])
shared.gradio['character_menu'].change(chat.load_character, [shared.gradio[k] for k in ['character_menu', 'name1', 'name2', 'mode']], [shared.gradio[k] for k in ['name1', 'name2', 'character_picture', 'greeting', 'context', 'end_of_turn', 'display']])
# Clearing stuff and saving the history
for i in ['Generate', 'Regenerate', 'Replace last reply']:
shared.gradio[i].click(lambda x: '', shared.gradio['textbox'], shared.gradio['textbox'], show_progress=False)
shared.gradio[i].click(lambda: chat.save_history(timestamp=False), [], [], show_progress=False)
shared.gradio['Clear history-confirm'].click(lambda: chat.save_history(timestamp=False), [], [], show_progress=False)
shared.gradio['textbox'].submit(lambda x: '', shared.gradio['textbox'], shared.gradio['textbox'], show_progress=False)
shared.gradio['textbox'].submit(lambda: chat.save_history(timestamp=False), [], [], show_progress=False)
shared.gradio['character_menu'].change(chat.load_character, [shared.gradio[k] for k in ['character_menu', 'name1', 'name2', 'Chat mode']], [shared.gradio[k] for k in ['name1', 'name2', 'character_picture', 'greeting', 'context', 'end_of_turn', 'display']])
shared.gradio['Instruction templates'].change(lambda character, name1, name2, mode: chat.load_character(character, name1, name2, mode), [shared.gradio[k] for k in ['Instruction templates', 'name1', 'name2', 'Chat mode']], [shared.gradio[k] for k in ['name1', 'name2', 'character_picture', 'greeting', 'context', 'end_of_turn', 'display']])
shared.gradio['upload_chat_history'].upload(chat.load_history, [shared.gradio[k] for k in ['upload_chat_history', 'name1', 'name2']], [])
shared.gradio['upload_img_tavern'].upload(chat.upload_tavern_character, [shared.gradio['upload_img_tavern'], shared.gradio['name1'], shared.gradio['name2']], [shared.gradio['character_menu']]) shared.gradio['upload_img_tavern'].upload(chat.upload_tavern_character, [shared.gradio['upload_img_tavern'], shared.gradio['name1'], shared.gradio['name2']], [shared.gradio['character_menu']])
shared.gradio['your_picture'].change(chat.upload_your_profile_picture, [shared.gradio[k] for k in ['your_picture', 'name1', 'name2', 'mode']], shared.gradio['display']) shared.gradio['your_picture'].change(chat.upload_your_profile_picture, [shared.gradio[k] for k in ['your_picture', 'name1', 'name2', 'Chat mode']], shared.gradio['display'])
reload_inputs = [shared.gradio[k] for k in ['name1', 'name2', 'Chat mode']]
shared.gradio['upload_chat_history'].upload(chat.redraw_html, reload_inputs, [shared.gradio['display']])
shared.gradio['Stop'].click(chat.redraw_html, reload_inputs, [shared.gradio['display']])
shared.gradio['Instruction templates'].change(chat.redraw_html, reload_inputs, [shared.gradio['display']])
shared.gradio['Chat mode'].change(chat.redraw_html, reload_inputs, [shared.gradio['display']])
shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js+ui.chat_js}}}") shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js+ui.chat_js}}}")
shared.gradio['interface'].load(chat.load_default_history, [shared.gradio[k] for k in ['name1', 'name2']], None) shared.gradio['interface'].load(lambda: chat.load_default_history(shared.settings['name1'], shared.settings['name2']), None, None)
shared.gradio['interface'].load(chat.redraw_html, reload_inputs, shared.gradio['display'], show_progress=True) shared.gradio['interface'].load(chat.redraw_html, reload_inputs, [shared.gradio['display']], show_progress=True)
elif shared.args.notebook: elif shared.args.notebook:
shared.input_elements = list_interface_input_elements(chat=False)
shared.gradio['interface_state'] = gr.State({k: None for k in shared.input_elements})
with gr.Tab("Text generation", elem_id="main"): with gr.Tab("Text generation", elem_id="main"):
with gr.Row(): with gr.Row():
with gr.Column(scale=4): with gr.Column(scale=4):
@@ -580,27 +479,14 @@ def create_interface():
with gr.Tab("Parameters", elem_id="parameters"): with gr.Tab("Parameters", elem_id="parameters"):
create_settings_menus(default_preset) create_settings_menus(default_preset)
shared.input_params = [shared.gradio[k] for k in ['textbox', 'interface_state']] shared.input_params = [shared.gradio[k] for k in ['textbox', 'generate_state']]
output_params = [shared.gradio[k] for k in ['textbox', 'markdown', 'html']] output_params = [shared.gradio[k] for k in ['textbox', 'markdown', 'html']]
gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream))
gen_events.append(shared.gradio['Generate'].click( gen_events.append(shared.gradio['textbox'].submit(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream))
gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then( shared.gradio['Stop'].click(stop_everything_event, [], [], queue=False, cancels=gen_events if shared.args.no_stream else None)
generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream)#.then(
#None, None, None, _js="() => {element = document.getElementsByTagName('textarea')[0]; element.scrollTop = element.scrollHeight}")
)
gen_events.append(shared.gradio['textbox'].submit(
gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream)#.then(
#None, None, None, _js="() => {element = document.getElementsByTagName('textarea')[0]; element.scrollTop = element.scrollHeight}")
)
shared.gradio['Stop'].click(stop_everything_event, None, None, queue=False, cancels=gen_events if shared.args.no_stream else None)
shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js}}}") shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js}}}")
else: else:
shared.input_elements = list_interface_input_elements(chat=False)
shared.gradio['interface_state'] = gr.State({k: None for k in shared.input_elements})
with gr.Tab("Text generation", elem_id="main"): with gr.Tab("Text generation", elem_id="main"):
with gr.Row(): with gr.Row():
with gr.Column(): with gr.Column():
@@ -626,28 +512,12 @@ def create_interface():
with gr.Tab("Parameters", elem_id="parameters"): with gr.Tab("Parameters", elem_id="parameters"):
create_settings_menus(default_preset) create_settings_menus(default_preset)
shared.input_params = [shared.gradio[k] for k in ['textbox', 'interface_state']] shared.input_params = [shared.gradio[k] for k in ['textbox', 'generate_state']]
output_params = [shared.gradio[k] for k in ['output_textbox', 'markdown', 'html']] output_params = [shared.gradio[k] for k in ['output_textbox', 'markdown', 'html']]
gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream))
gen_events.append(shared.gradio['Generate'].click( gen_events.append(shared.gradio['textbox'].submit(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream))
gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then( gen_events.append(shared.gradio['Continue'].click(generate_reply, [shared.gradio['output_textbox']] + shared.input_params[1:], output_params, show_progress=shared.args.no_stream))
generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream)#.then( shared.gradio['Stop'].click(stop_everything_event, [], [], queue=False, cancels=gen_events if shared.args.no_stream else None)
#None, None, None, _js="() => {element = document.getElementsByTagName('textarea')[1]; element.scrollTop = element.scrollHeight}")
)
gen_events.append(shared.gradio['textbox'].submit(
gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream)#.then(
#None, None, None, _js="() => {element = document.getElementsByTagName('textarea')[1]; element.scrollTop = element.scrollHeight}")
)
gen_events.append(shared.gradio['Continue'].click(
gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
generate_reply, [shared.gradio['output_textbox']] + shared.input_params[1:], output_params, show_progress=shared.args.no_stream)#.then(
#None, None, None, _js="() => {element = document.getElementsByTagName('textarea')[1]; element.scrollTop = element.scrollHeight}")
)
shared.gradio['Stop'].click(stop_everything_event, None, None, queue=False, cancels=gen_events if shared.args.no_stream else None)
shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js}}}") shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js}}}")
with gr.Tab("Model", elem_id="model-tab"): with gr.Tab("Model", elem_id="model-tab"):
@@ -671,16 +541,26 @@ def create_interface():
shared.gradio['interface_modes_menu'] = gr.Dropdown(choices=modes, value=current_mode, label="Mode") shared.gradio['interface_modes_menu'] = gr.Dropdown(choices=modes, value=current_mode, label="Mode")
shared.gradio['extensions_menu'] = gr.CheckboxGroup(choices=get_available_extensions(), value=shared.args.extensions, label="Available extensions") shared.gradio['extensions_menu'] = gr.CheckboxGroup(choices=get_available_extensions(), value=shared.args.extensions, label="Available extensions")
shared.gradio['bool_menu'] = gr.CheckboxGroup(choices=bool_list, value=bool_active, label="Boolean command-line flags") shared.gradio['bool_menu'] = gr.CheckboxGroup(choices=bool_list, value=bool_active, label="Boolean command-line flags")
shared.gradio['reset_interface'] = gr.Button("Apply and restart the interface") shared.gradio['reset_interface'] = gr.Button("Apply and restart the interface", type="primary")
# Reset interface event shared.gradio['reset_interface'].click(set_interface_arguments, [shared.gradio[k] for k in ['interface_modes_menu', 'extensions_menu', 'bool_menu']], None)
shared.gradio['reset_interface'].click( shared.gradio['reset_interface'].click(lambda: None, None, None, _js='() => {document.body.innerHTML=\'<h1 style="font-family:monospace;margin-top:20%;color:lightgray;text-align:center;">Reloading...</h1>\'; setTimeout(function(){location.reload()},2500); return []}')
set_interface_arguments, [shared.gradio[k] for k in ['interface_modes_menu', 'extensions_menu', 'bool_menu']], None).then(
lambda: None, None, None, _js='() => {document.body.innerHTML=\'<h1 style="font-family:monospace;margin-top:20%;color:lightgray;text-align:center;">Reloading...</h1>\'; setTimeout(function(){location.reload()},2500); return []}')
if shared.args.extensions is not None: if shared.args.extensions is not None:
extensions_module.create_extensions_block() extensions_module.create_extensions_block()
def change_dict_value(d, key, value):
d[key] = value
return d
for k in ['do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'max_new_tokens', 'seed', 'stop_at_newline', 'chat_prompt_size_slider', 'chat_generation_attempts']:
if k not in shared.gradio:
continue
if type(shared.gradio[k]) in [gr.Checkbox, gr.Number]:
shared.gradio[k].change(lambda state, value, copy=k: change_dict_value(state, copy, value), inputs=[shared.gradio['generate_state'], shared.gradio[k]], outputs=shared.gradio['generate_state'])
else:
shared.gradio[k].release(lambda state, value, copy=k: change_dict_value(state, copy, value), inputs=[shared.gradio['generate_state'], shared.gradio[k]], outputs=shared.gradio['generate_state'])
if not shared.is_chat(): if not shared.is_chat():
api.create_apis() api.create_apis()

View File

@@ -7,14 +7,7 @@
"name2": "Assistant", "name2": "Assistant",
"context": "This is a conversation with your Assistant. The Assistant is very helpful and is eager to chat with you and answer your questions.", "context": "This is a conversation with your Assistant. The Assistant is very helpful and is eager to chat with you and answer your questions.",
"greeting": "Hello there!", "greeting": "Hello there!",
"end_of_turn": "",
"custom_stopping_strings": "",
"stop_at_newline": false, "stop_at_newline": false,
"add_bos_token": true,
"ban_eos_token": true,
"truncation_length": 2048,
"truncation_length_min": 0,
"truncation_length_max": 4096,
"chat_prompt_size": 2048, "chat_prompt_size": 2048,
"chat_prompt_size_min": 0, "chat_prompt_size_min": 0,
"chat_prompt_size_max": 2048, "chat_prompt_size_max": 2048,
@@ -26,8 +19,7 @@
"gallery" "gallery"
], ],
"presets": { "presets": {
"default": "Default", "default": "NovelAI-Sphinx Moth",
".*(alpaca|llama)": "LLaMA-Precise",
".*pygmalion": "NovelAI-Storywriter", ".*pygmalion": "NovelAI-Storywriter",
".*RWKV": "Naive" ".*RWKV": "Naive"
}, },