38 Commits

Author SHA1 Message Date
oobabooga
f57ebb6c42 Revert unintended changes 2023-04-07 00:13:14 -03:00
oobabooga
01cacfc14f Reformat everything 2023-04-07 00:08:46 -03:00
oobabooga
848c4edfd5 Update README.md 2023-04-06 22:52:35 -03:00
oobabooga
e047cd1def Update README 2023-04-06 22:50:58 -03:00
loeken
08b9d1b23a creating a layer with Docker/docker-compose (#633) 2023-04-06 22:46:04 -03:00
oobabooga
64bcde56ab Minor css change 2023-04-06 20:14:29 -03:00
oobabooga
58ed87e5d9 Update requirements.txt 2023-04-06 18:42:54 -03:00
dependabot[bot]
21be80242e Bump rwkv from 0.7.2 to 0.7.3 (#842) 2023-04-06 17:52:27 -03:00
OWKenobi
310bf46a94 Instruction Character Vicuna, Instruction Mode Bugfix (#838) 2023-04-06 17:40:44 -03:00
DavG25
20b8ca4482 Add CSS for lists (#833) 2023-04-06 16:15:04 -03:00
oobabooga
113f94b61e Bump transformers (16-bit llama must be reconverted/redownloaded) 2023-04-06 16:04:03 -03:00
oobabooga
5f4f38ca5d Merge branch 'main' of github.com:oobabooga/text-generation-webui 2023-04-06 14:38:29 -03:00
oobabooga
d9e7aba714 Update README.md 2023-04-06 13:42:24 -03:00
oobabooga
59058576b5 Remove unused requirement 2023-04-06 13:28:21 -03:00
oobabooga
eec3665845 Add instructions for updating requirements 2023-04-06 13:24:01 -03:00
oobabooga
03cb44fc8c Add new llama.cpp library (2048 context, temperature, etc now work) 2023-04-06 13:12:14 -03:00
EyeDeck
39f3fec913 Broaden GPTQ-for-LLaMA branch support (#820) 2023-04-06 12:16:48 -03:00
oobabooga
8cd899515e Change instruct html a bit 2023-04-06 12:00:20 -03:00
oobabooga
4a28f39823 Update README.md 2023-04-06 02:47:27 -03:00
oobabooga
158ec51ae3 Increase instruct mode padding 2023-04-06 02:20:52 -03:00
Alex "mcmonkey" Goodwin
0c7ef26981 Lora trainer improvements (#763) 2023-04-06 02:04:11 -03:00
oobabooga
5b301d9a02 Create a Model tab 2023-04-06 01:54:05 -03:00
oobabooga
4a400320dd Clean up 2023-04-06 01:47:00 -03:00
oobabooga
e94ab5dac1 Minor fixes 2023-04-06 01:43:10 -03:00
Randell Miller
641646a801 Fix crash if missing instructions directory (#812) 2023-04-06 01:24:22 -03:00
oobabooga
3f3e42e26c Refactor several function calls and the API 2023-04-06 01:22:15 -03:00
SDS
378d21e80c Add LLaMA-Precise preset (#767) 2023-04-05 18:52:36 -03:00
eiery
19b516b11b fix link to streaming api example (#803) 2023-04-05 14:50:23 -03:00
oobabooga
7617ed5bfd Add AMD instructions 2023-04-05 14:42:58 -03:00
oobabooga
770ef5744f Update README 2023-04-05 14:38:11 -03:00
Forkoz
8203ce0cac Stop character pic from being cached when changing chars or clearing. (#798)
Tested on both FF and chromium
2023-04-05 14:25:01 -03:00
oobabooga
7f66421369 Fix loading characters 2023-04-05 14:22:32 -03:00
oobabooga
90141bc1a8 Fix saving prompts on Windows 2023-04-05 14:08:54 -03:00
oobabooga
cf2c4e740b Disable gradio analytics globally 2023-04-05 14:05:50 -03:00
oobabooga
e722c240af Add Instruct mode 2023-04-05 13:54:50 -03:00
oobabooga
3d6cb5ed63 Minor rewrite 2023-04-05 01:21:40 -03:00
oobabooga
f3a2e0b8a9 Disable pre_layer when the model type is not llama 2023-04-05 01:19:26 -03:00
oobabooga
ca8bb38949 Simplify gallery 2023-04-05 00:34:17 -03:00
42 changed files with 1025 additions and 415 deletions

10
.dockerignore Normal file
View File

@@ -0,0 +1,10 @@
.env
Dockerfile
/characters
/extensions
/loras
/models
/presets
/prompts
/softprompts
/training

25
.env.example Normal file
View File

@@ -0,0 +1,25 @@
# by default the Dockerfile specifies these versions: 3.5;5.0;6.0;6.1;7.0;7.5;8.0;8.6+PTX
# however for me to work i had to specify the exact version for my card ( 2060 ) it was 7.5
# https://developer.nvidia.com/cuda-gpus you can find the version for your card here
TORCH_CUDA_ARCH_LIST=7.5
# these commands worked for me with roughly 4.5GB of vram
CLI_ARGS=--model llama-7b-4bit --wbits 4 --listen --auto-devices
# the following examples have been tested with the files linked in docs/README_docker.md:
# example running 13b with 4bit/128 groupsize : CLI_ARGS=--model llama-13b-4bit-128g --wbits 4 --listen --groupsize 128 --pre_layer 25
# example with loading api extension and public share: CLI_ARGS=--model llama-7b-4bit --wbits 4 --listen --auto-devices --no-stream --extensions api --share
# example running 7b with 8bit groupsize : CLI_ARGS=--model llama-7b --load-in-8bit --listen --auto-devices
# the port the webui binds to on the host
HOST_PORT=7860
# the port the webui binds to inside the container
CONTAINER_PORT=7860
# the port the api binds to on the host
HOST_API_PORT=5000
# the port the api binds to inside the container
CONTAINER_API_PORT=5000
# the version used to install text-generation-webui from
WEBUI_VERSION=HEAD

61
Dockerfile Normal file
View File

@@ -0,0 +1,61 @@
FROM nvidia/cuda:11.8.0-devel-ubuntu22.04 as builder
RUN apt-get update && \
apt-get install --no-install-recommends -y git vim build-essential python3-dev python3-venv && \
rm -rf /var/lib/apt/lists/*
RUN git clone https://github.com/oobabooga/GPTQ-for-LLaMa /build
WORKDIR /build
RUN python3 -m venv /build/venv
RUN . /build/venv/bin/activate && \
pip3 install --upgrade pip setuptools && \
pip3 install torch torchvision torchaudio && \
pip3 install -r requirements.txt
# https://developer.nvidia.com/cuda-gpus
# for a rtx 2060: ARG TORCH_CUDA_ARCH_LIST="7.5"
ARG TORCH_CUDA_ARCH_LIST="3.5;5.0;6.0;6.1;7.0;7.5;8.0;8.6+PTX"
RUN . /build/venv/bin/activate && \
python3 setup_cuda.py bdist_wheel -d .
FROM nvidia/cuda:11.8.0-runtime-ubuntu22.04
LABEL maintainer="Your Name <your.email@example.com>"
LABEL description="Docker image for GPTQ-for-LLaMa and Text Generation WebUI"
RUN apt-get update && \
apt-get install --no-install-recommends -y git python3 python3-pip && \
rm -rf /var/lib/apt/lists/*
RUN --mount=type=cache,target=/root/.cache/pip pip3 install virtualenv
COPY . /app/
WORKDIR /app
ARG WEBUI_VERSION
RUN test -n "${WEBUI_VERSION}" && git reset --hard ${WEBUI_VERSION} || echo "Using provided webui source"
RUN virtualenv /app/venv
RUN . /app/venv/bin/activate && \
pip3 install --upgrade pip setuptools && \
pip3 install torch torchvision torchaudio && \
pip3 install -r requirements.txt
COPY --from=builder /build /app/repositories/GPTQ-for-LLaMa
RUN . /app/venv/bin/activate && \
pip3 install /app/repositories/GPTQ-for-LLaMa/*.whl
ENV CLI_ARGS=""
RUN --mount=type=cache,target=/root/.cache/pip . /app/venv/bin/activate && cd extensions/api && pip3 install -r requirements.txt
RUN --mount=type=cache,target=/root/.cache/pip . /app/venv/bin/activate && cd extensions/elevenlabs_tts && pip3 install -r requirements.txt
RUN --mount=type=cache,target=/root/.cache/pip . /app/venv/bin/activate && cd extensions/google_translate && pip3 install -r requirements.txt
RUN --mount=type=cache,target=/root/.cache/pip . /app/venv/bin/activate && cd extensions/silero_tts && pip3 install -r requirements.txt
RUN --mount=type=cache,target=/root/.cache/pip . /app/venv/bin/activate && cd extensions/whisper_stt && pip3 install -r requirements.txt
RUN cp /app/venv/lib/python3.10/site-packages/bitsandbytes/libbitsandbytes_cuda118.so /app/venv/lib/python3.10/site-packages/bitsandbytes/libbitsandbytes_cpu.so
CMD . /app/venv/bin/activate && python3 server.py ${CLI_ARGS}

View File

@@ -15,6 +15,7 @@ Its goal is to become the [AUTOMATIC1111/stable-diffusion-webui](https://github.
* Dropdown menu for switching between models * Dropdown menu for switching between models
* Notebook mode that resembles OpenAI's playground * Notebook mode that resembles OpenAI's playground
* Chat mode for conversation and role playing * Chat mode for conversation and role playing
* Instruct mode compatible with Alpaca and Open Assistant formats **\*NEW!\***
* Nice HTML output for GPT-4chan * Nice HTML output for GPT-4chan
* Markdown output for [GALACTICA](https://github.com/paperswithcode/galai), including LaTeX rendering * Markdown output for [GALACTICA](https://github.com/paperswithcode/galai), including LaTeX rendering
* [Custom chat characters](https://github.com/oobabooga/text-generation-webui/wiki/Custom-chat-characters) * [Custom chat characters](https://github.com/oobabooga/text-generation-webui/wiki/Custom-chat-characters)
@@ -26,11 +27,11 @@ Its goal is to become the [AUTOMATIC1111/stable-diffusion-webui](https://github.
* CPU mode * CPU mode
* [FlexGen](https://github.com/oobabooga/text-generation-webui/wiki/FlexGen) * [FlexGen](https://github.com/oobabooga/text-generation-webui/wiki/FlexGen)
* [DeepSpeed ZeRO-3](https://github.com/oobabooga/text-generation-webui/wiki/DeepSpeed) * [DeepSpeed ZeRO-3](https://github.com/oobabooga/text-generation-webui/wiki/DeepSpeed)
* API [with](https://github.com/oobabooga/text-generation-webui/blob/main/api-example-streaming.py) streaming and [without](https://github.com/oobabooga/text-generation-webui/blob/main/api-example.py) streaming * API [with](https://github.com/oobabooga/text-generation-webui/blob/main/api-example-stream.py) streaming and [without](https://github.com/oobabooga/text-generation-webui/blob/main/api-example.py) streaming
* [LLaMA model, including 4-bit GPTQ](https://github.com/oobabooga/text-generation-webui/wiki/LLaMA-model) * [LLaMA model, including 4-bit GPTQ](https://github.com/oobabooga/text-generation-webui/wiki/LLaMA-model)
* [llama.cpp](https://github.com/oobabooga/text-generation-webui/wiki/llama.cpp-models) **\*NEW!\*** * [llama.cpp](https://github.com/oobabooga/text-generation-webui/wiki/llama.cpp-models) **\*NEW!\***
* [RWKV model](https://github.com/oobabooga/text-generation-webui/wiki/RWKV-model) * [RWKV model](https://github.com/oobabooga/text-generation-webui/wiki/RWKV-model)
* [LoRa (loading and training)](https://github.com/oobabooga/text-generation-webui/wiki/Using-LoRAs) * [LoRA (loading and training)](https://github.com/oobabooga/text-generation-webui/wiki/Using-LoRAs)
* Softprompts * Softprompts
* [Extensions](https://github.com/oobabooga/text-generation-webui/wiki/Extensions) * [Extensions](https://github.com/oobabooga/text-generation-webui/wiki/Extensions)
* [Google Colab](https://github.com/oobabooga/text-generation-webui/wiki/Running-on-Colab) * [Google Colab](https://github.com/oobabooga/text-generation-webui/wiki/Running-on-Colab)
@@ -62,7 +63,7 @@ Recommended if you have some experience with the command-line.
On Windows, I additionally recommend carrying out the installation on WSL instead of the base system: [WSL installation guide](https://github.com/oobabooga/text-generation-webui/wiki/WSL-installation-guide). On Windows, I additionally recommend carrying out the installation on WSL instead of the base system: [WSL installation guide](https://github.com/oobabooga/text-generation-webui/wiki/WSL-installation-guide).
0. Install Conda #### 0. Install Conda
https://docs.conda.io/en/latest/miniconda.html https://docs.conda.io/en/latest/miniconda.html
@@ -75,14 +76,14 @@ bash Miniconda3.sh
Source: https://educe-ubc.github.io/conda.html Source: https://educe-ubc.github.io/conda.html
1. Create a new conda environment #### 1. Create a new conda environment
``` ```
conda create -n textgen python=3.10.9 conda create -n textgen python=3.10.9
conda activate textgen conda activate textgen
``` ```
2. Install Pytorch #### 2. Install Pytorch
| System | GPU | Command | | System | GPU | Command |
|--------|---------|---------| |--------|---------|---------|
@@ -92,10 +93,12 @@ conda activate textgen
The up to date commands can be found here: https://pytorch.org/get-started/locally/. The up to date commands can be found here: https://pytorch.org/get-started/locally/.
MacOS users, refer to the comments here: https://github.com/oobabooga/text-generation-webui/pull/393 #### 2.1 Special instructions
* MacOS users: https://github.com/oobabooga/text-generation-webui/pull/393
* AMD users: https://rentry.org/eq3hg
3. Install the web UI #### 3. Install the web UI
``` ```
git clone https://github.com/oobabooga/text-generation-webui git clone https://github.com/oobabooga/text-generation-webui
@@ -114,8 +117,26 @@ As an alternative to the recommended WSL method, you can install the web UI nati
### Alternative: Docker ### Alternative: Docker
https://github.com/oobabooga/text-generation-webui/issues/174, https://github.com/oobabooga/text-generation-webui/issues/87 ```
cp .env.example .env
docker-compose up --build
```
Make sure to edit `.env.example` and set the appropriate CUDA version for your GPU.
You need to have docker compose v2.17 or higher installed in your system. For installation instructions, see [Docker compose installation](https://github.com/oobabooga/text-generation-webui/wiki/Docker-compose-installation).
Contributed by [@loeken](https://github.com/loeken) in [#633](https://github.com/oobabooga/text-generation-webui/pull/633)
### Updating the requirements
From time to time, the `requirements.txt` changes. To update, use this command:
```
conda activate textgen
cd text-generation-webui
pip install -r requirements.txt --upgrade
```
## Downloading models ## Downloading models
Models should be placed inside the `models` folder. Models should be placed inside the `models` folder.
@@ -175,7 +196,6 @@ Optionally, you can use the following command-line flags:
| `-h`, `--help` | show this help message and exit | | `-h`, `--help` | show this help message and exit |
| `--notebook` | Launch the web UI in notebook mode, where the output is written to the same text box as the input. | | `--notebook` | Launch the web UI in notebook mode, where the output is written to the same text box as the input. |
| `--chat` | Launch the web UI in chat mode.| | `--chat` | Launch the web UI in chat mode.|
| `--cai-chat` | Launch the web UI in chat mode with a style similar to the Character.AI website. |
| `--model MODEL` | Name of the model to load by default. | | `--model MODEL` | Name of the model to load by default. |
| `--lora LORA` | Name of the LoRA to apply to the model by default. | | `--lora LORA` | Name of the LoRA to apply to the model by default. |
| `--model-dir MODEL_DIR` | Path to directory with all the models | | `--model-dir MODEL_DIR` | Path to directory with all the models |

View File

@@ -17,6 +17,7 @@ def random_hash():
letters = string.ascii_lowercase + string.digits letters = string.ascii_lowercase + string.digits
return ''.join(random.choice(letters) for i in range(9)) return ''.join(random.choice(letters) for i in range(9))
async def run(context): async def run(context):
server = "127.0.0.1" server = "127.0.0.1"
params = { params = {
@@ -36,6 +37,7 @@ async def run(context):
'early_stopping': False, 'early_stopping': False,
'seed': -1, 'seed': -1,
} }
payload = json.dumps([context, params])
session = random_hash() session = random_hash()
async with websockets.connect(f"ws://{server}:7860/queue/join") as websocket: async with websockets.connect(f"ws://{server}:7860/queue/join") as websocket:
@@ -54,22 +56,7 @@ async def run(context):
"session_hash": session, "session_hash": session,
"fn_index": 12, "fn_index": 12,
"data": [ "data": [
context, payload
params['max_new_tokens'],
params['do_sample'],
params['temperature'],
params['top_p'],
params['typical_p'],
params['repetition_penalty'],
params['encoder_repetition_penalty'],
params['top_k'],
params['min_length'],
params['no_repeat_ngram_size'],
params['num_beams'],
params['penalty_alpha'],
params['length_penalty'],
params['early_stopping'],
params['seed'],
] ]
})) }))
case "process_starts": case "process_starts":
@@ -83,6 +70,7 @@ async def run(context):
prompt = "What I would like to say is the following: " prompt = "What I would like to say is the following: "
async def get_result(): async def get_result():
async for response in run(prompt): async for response in run(prompt):
# Print intermediate steps # Print intermediate steps

View File

@@ -10,6 +10,8 @@ Optionally, you can also add the --share flag to generate a public gradio URL,
allowing you to use the API remotely. allowing you to use the API remotely.
''' '''
import json
import requests import requests
# Server address # Server address
@@ -38,24 +40,11 @@ params = {
# Input prompt # Input prompt
prompt = "What I would like to say is the following: " prompt = "What I would like to say is the following: "
payload = json.dumps([prompt, params])
response = requests.post(f"http://{server}:7860/run/textgen", json={ response = requests.post(f"http://{server}:7860/run/textgen", json={
"data": [ "data": [
prompt, payload
params['max_new_tokens'],
params['do_sample'],
params['temperature'],
params['top_p'],
params['typical_p'],
params['repetition_penalty'],
params['encoder_repetition_penalty'],
params['top_k'],
params['min_length'],
params['no_repeat_ngram_size'],
params['num_beams'],
params['penalty_alpha'],
params['length_penalty'],
params['early_stopping'],
params['seed'],
] ]
}).json() }).json()

View File

@@ -0,0 +1,3 @@
name: "### Response:"
your_name: "### Instruction:"
context: "Below is an instruction that describes a task. Write a response that appropriately completes the request."

View File

@@ -0,0 +1,3 @@
name: "<|assistant|>"
your_name: "<|prompter|>"
end_of_turn: "<|endoftext|>"

View File

@@ -0,0 +1,3 @@
name: "### Assistant:"
your_name: "### Human:"
context: "Below is an instruction that describes a task. Write a response that appropriately completes the request."

View File

@@ -17,6 +17,7 @@ parser = argparse.ArgumentParser(formatter_class=lambda prog: argparse.HelpForma
parser.add_argument('MODEL', type=str, default=None, nargs='?', help="Path to the input model.") parser.add_argument('MODEL', type=str, default=None, nargs='?', help="Path to the input model.")
args = parser.parse_args() args = parser.parse_args()
def disable_torch_init(): def disable_torch_init():
""" """
Disable the redundant torch default initialization to accelerate model creation. Disable the redundant torch default initialization to accelerate model creation.
@@ -31,12 +32,14 @@ def disable_torch_init():
torch_layer_norm_init_backup = torch.nn.LayerNorm.reset_parameters torch_layer_norm_init_backup = torch.nn.LayerNorm.reset_parameters
setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None) setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
def restore_torch_init(): def restore_torch_init():
"""Rollback the change made by disable_torch_init.""" """Rollback the change made by disable_torch_init."""
import torch import torch
setattr(torch.nn.Linear, "reset_parameters", torch_linear_init_backup) setattr(torch.nn.Linear, "reset_parameters", torch_linear_init_backup)
setattr(torch.nn.LayerNorm, "reset_parameters", torch_layer_norm_init_backup) setattr(torch.nn.LayerNorm, "reset_parameters", torch_layer_norm_init_backup)
if __name__ == '__main__': if __name__ == '__main__':
path = Path(args.MODEL) path = Path(args.MODEL)
model_name = path.name model_name = path.name

View File

@@ -64,6 +64,15 @@
line-height: 1.428571429 !important; line-height: 1.428571429 !important;
} }
.message-body li {
margin-top: 0.5em !important;
margin-bottom: 0.5em !important;
}
.message-body li > p {
display: inline !important;
}
.dark .message-body p em { .dark .message-body p em {
color: rgb(138, 138, 138) !important; color: rgb(138, 138, 138) !important;
} }

View File

@@ -0,0 +1,65 @@
.chat {
margin-left: auto;
margin-right: auto;
max-width: 800px;
height: 66.67vh;
overflow-y: auto;
padding-right: 20px;
display: flex;
flex-direction: column-reverse;
}
.message {
display: grid;
grid-template-columns: 60px 1fr;
padding-bottom: 25px;
font-size: 15px;
font-family: Helvetica, Arial, sans-serif;
line-height: 1.428571429;
}
.username {
display: none;
}
.message-body {}
.message-body p {
margin-bottom: 0 !important;
font-size: 15px !important;
line-height: 1.428571429 !important;
}
.message-body li {
margin-top: 0.5em !important;
margin-bottom: 0.5em !important;
}
.message-body li > p {
display: inline !important;
}
.dark .message-body p em {
color: rgb(138, 138, 138) !important;
}
.message-body p em {
color: rgb(110, 110, 110) !important;
}
.gradio-container .chat .assistant-message {
padding: 15px;
border-radius: 20px;
background-color: #0000000f;
margin-bottom: 17.5px;
}
.gradio-container .chat .user-message {
padding: 15px;
border-radius: 20px;
margin-bottom: 17.5px !important;
}
.dark .chat .assistant-message {
background-color: #ffffff21;
}

View File

@@ -41,7 +41,7 @@ ol li p, ul li p {
display: inline-block; display: inline-block;
} }
#main, #parameters, #chat-settings, #interface-mode, #lora, #training-tab { #main, #parameters, #chat-settings, #interface-mode, #lora, #training-tab, #model-tab {
border: 0; border: 0;
} }

32
docker-compose.yml Normal file
View File

@@ -0,0 +1,32 @@
version: "3.3"
services:
text-generation-webui:
build:
context: .
args:
# specify which cuda version your card supports: https://developer.nvidia.com/cuda-gpus
TORCH_CUDA_ARCH_LIST: ${TORCH_CUDA_ARCH_LIST}
GPTQ_VERSION: ${GPTQ_VERSION}
WEBUI_VERSION: ${WEBUI_VERSION}
env_file: .env
ports:
- "${HOST_PORT}:${CONTAINER_PORT}"
- "${HOST_API_PORT}:${CONTAINER_API_PORT}"
stdin_open: true
tty: true
volumes:
- ./characters:/app/characters
- ./extensions:/app/extensions
- ./loras:/app/loras
- ./models:/app/models
- ./presets:/app/presets
- ./prompts:/app/prompts
- ./softprompts:/app/softprompts
- ./training:/app/training
deploy:
resources:
reservations:
devices:
- driver: nvidia
device_ids: ['0']
capabilities: [gpu]

View File

@@ -29,6 +29,7 @@ parser.add_argument('--clean', action='store_true', help='Does not resume the pr
parser.add_argument('--check', action='store_true', help='Validates the checksums of model files.') parser.add_argument('--check', action='store_true', help='Validates the checksums of model files.')
args = parser.parse_args() args = parser.parse_args()
def get_file(url, output_folder): def get_file(url, output_folder):
filename = Path(url.rsplit('/', 1)[1]) filename = Path(url.rsplit('/', 1)[1])
output_path = output_folder / filename output_path = output_folder / filename
@@ -54,6 +55,7 @@ def get_file(url, output_folder):
t.update(len(data)) t.update(len(data))
f.write(data) f.write(data)
def sanitize_branch_name(branch_name): def sanitize_branch_name(branch_name):
pattern = re.compile(r"^[a-zA-Z0-9._-]+$") pattern = re.compile(r"^[a-zA-Z0-9._-]+$")
if pattern.match(branch_name): if pattern.match(branch_name):
@@ -61,6 +63,7 @@ def sanitize_branch_name(branch_name):
else: else:
raise ValueError("Invalid branch name. Only alphanumeric characters, period, underscore and dash are allowed.") raise ValueError("Invalid branch name. Only alphanumeric characters, period, underscore and dash are allowed.")
def select_model_from_default_options(): def select_model_from_default_options():
models = { models = {
"OPT 6.7B": ("facebook", "opt-6.7b", "main"), "OPT 6.7B": ("facebook", "opt-6.7b", "main"),
@@ -106,6 +109,7 @@ EleutherAI/pythia-1.4b-deduped
return model, branch return model, branch
def get_download_links_from_huggingface(model, branch): def get_download_links_from_huggingface(model, branch):
base = "https://huggingface.co" base = "https://huggingface.co"
page = f"/api/models/{model}/tree/{branch}?cursor=" page = f"/api/models/{model}/tree/{branch}?cursor="
@@ -172,9 +176,11 @@ def get_download_links_from_huggingface(model, branch):
return links, sha256, is_lora return links, sha256, is_lora
def download_files(file_list, output_folder, num_threads=8): def download_files(file_list, output_folder, num_threads=8):
thread_map(lambda url: get_file(url, output_folder), file_list, max_workers=num_threads, disable=True) thread_map(lambda url: get_file(url, output_folder), file_list, max_workers=num_threads, disable=True)
if __name__ == '__main__': if __name__ == '__main__':
model = args.MODEL model = args.MODEL
branch = args.branch branch = args.branch

View File

@@ -9,6 +9,7 @@ params = {
'port': 5000, 'port': 5000,
} }
class Handler(BaseHTTPRequestHandler): class Handler(BaseHTTPRequestHandler):
def do_GET(self): def do_GET(self):
if self.path == '/api/v1/model': if self.path == '/api/v1/model':
@@ -32,7 +33,7 @@ class Handler(BaseHTTPRequestHandler):
self.end_headers() self.end_headers()
prompt = body['prompt'] prompt = body['prompt']
prompt_lines = [l.strip() for l in prompt.split('\n')] prompt_lines = [k.strip() for k in prompt.split('\n')]
max_context = body.get('max_context_length', 2048) max_context = body.get('max_context_length', 2048)
@@ -40,24 +41,27 @@ class Handler(BaseHTTPRequestHandler):
prompt_lines.pop(0) prompt_lines.pop(0)
prompt = '\n'.join(prompt_lines) prompt = '\n'.join(prompt_lines)
generate_params = {
'max_new_tokens': int(body.get('max_length', 200)),
'do_sample': bool(body.get('do_sample', True)),
'temperature': float(body.get('temperature', 0.5)),
'top_p': float(body.get('top_p', 1)),
'typical_p': float(body.get('typical', 1)),
'repetition_penalty': float(body.get('rep_pen', 1.1)),
'encoder_repetition_penalty': 1,
'top_k': int(body.get('top_k', 0)),
'min_length': int(body.get('min_length', 0)),
'no_repeat_ngram_size': int(body.get('no_repeat_ngram_size', 0)),
'num_beams': int(body.get('num_beams', 1)),
'penalty_alpha': float(body.get('penalty_alpha', 0)),
'length_penalty': float(body.get('length_penalty', 1)),
'early_stopping': bool(body.get('early_stopping', False)),
'seed': int(body.get('seed', -1)),
}
generator = generate_reply( generator = generate_reply(
question = prompt, prompt,
max_new_tokens = int(body.get('max_length', 200)), generate_params,
do_sample=bool(body.get('do_sample', True)),
temperature=float(body.get('temperature', 0.5)),
top_p=float(body.get('top_p', 1)),
typical_p=float(body.get('typical', 1)),
repetition_penalty=float(body.get('rep_pen', 1.1)),
encoder_repetition_penalty=1,
top_k=int(body.get('top_k', 0)),
min_length=int(body.get('min_length', 0)),
no_repeat_ngram_size=int(body.get('no_repeat_ngram_size',0)),
num_beams=int(body.get('num_beams',1)),
penalty_alpha=float(body.get('penalty_alpha', 0)),
length_penalty=float(body.get('length_penalty', 1)),
early_stopping=bool(body.get('early_stopping', False)),
seed=int(body.get('seed', -1)),
stopping_strings=body.get('stopping_strings', []), stopping_strings=body.get('stopping_strings', []),
) )
@@ -92,5 +96,6 @@ def run_server():
print(f'Starting KoboldAI compatible api at http://{server_addr[0]}:{server_addr[1]}/api') print(f'Starting KoboldAI compatible api at http://{server_addr[0]}:{server_addr[1]}/api')
server.serve_forever() server.serve_forever()
def setup(): def setup():
Thread(target=run_server, daemon=True).start() Thread(target=run_server, daemon=True).start()

View File

@@ -5,6 +5,7 @@ params = {
"bias string": " *I am so happy*", "bias string": " *I am so happy*",
} }
def input_modifier(string): def input_modifier(string):
""" """
This function is applied to your text inputs before This function is applied to your text inputs before
@@ -13,6 +14,7 @@ def input_modifier(string):
return string return string
def output_modifier(string): def output_modifier(string):
""" """
This function is applied to the model outputs. This function is applied to the model outputs.
@@ -20,6 +22,7 @@ def output_modifier(string):
return string return string
def bot_prefix_modifier(string): def bot_prefix_modifier(string):
""" """
This function is only applied in chat mode. It modifies This function is only applied in chat mode. It modifies
@@ -27,11 +30,12 @@ def bot_prefix_modifier(string):
behavior. behavior.
""" """
if params['activate'] == True: if params['activate']:
return f'{string} {params["bias string"].strip()} ' return f'{string} {params["bias string"].strip()} '
else: else:
return string return string
def ui(): def ui():
# Gradio elements # Gradio elements
activate = gr.Checkbox(value=params['activate'], label='Activate character bias') activate = gr.Checkbox(value=params['activate'], label='Activate character bias')

View File

@@ -22,6 +22,8 @@ if not shared.args.no_stream:
raise ValueError raise ValueError
# Check if the API is valid and refresh the UI accordingly. # Check if the API is valid and refresh the UI accordingly.
def check_valid_api(): def check_valid_api():
global user, user_info, params global user, user_info, params
@@ -29,7 +31,7 @@ def check_valid_api():
user = ElevenLabsUser(params['api_key']) user = ElevenLabsUser(params['api_key'])
user_info = user._get_subscription_data() user_info = user._get_subscription_data()
print('checking api') print('checking api')
if params['activate'] == False: if not params['activate']:
return gr.update(value='Disconnected') return gr.update(value='Disconnected')
elif user_info is None: elif user_info is None:
print('Incorrect API Key') print('Incorrect API Key')
@@ -39,6 +41,8 @@ def check_valid_api():
return gr.update(value='Connected') return gr.update(value='Connected')
# Once the API is verified, get the available voices and update the dropdown list # Once the API is verified, get the available voices and update the dropdown list
def refresh_voices(): def refresh_voices():
global user, user_info global user, user_info
@@ -51,11 +55,13 @@ def refresh_voices():
else: else:
return return
def remove_surrounded_chars(string): def remove_surrounded_chars(string):
# this expression matches to 'as few symbols as possible (0 upwards) between any asterisks' OR # this expression matches to 'as few symbols as possible (0 upwards) between any asterisks' OR
# 'as few symbols as possible (0 upwards) between an asterisk and the end of the string' # 'as few symbols as possible (0 upwards) between an asterisk and the end of the string'
return re.sub('\*[^\*]*?(\*|$)', '', string) return re.sub('\*[^\*]*?(\*|$)', '', string)
def input_modifier(string): def input_modifier(string):
""" """
This function is applied to your text inputs before This function is applied to your text inputs before
@@ -64,6 +70,7 @@ def input_modifier(string):
return string return string
def output_modifier(string): def output_modifier(string):
""" """
This function is applied to the model outputs. This function is applied to the model outputs.
@@ -71,9 +78,9 @@ def output_modifier(string):
global params, wav_idx, user, user_info global params, wav_idx, user, user_info
if params['activate'] == False: if not params['activate']:
return string return string
elif user_info == None: elif user_info is None:
return string return string
string = remove_surrounded_chars(string) string = remove_surrounded_chars(string)
@@ -94,6 +101,7 @@ def output_modifier(string):
wav_idx += 1 wav_idx += 1
return string return string
def ui(): def ui():
# Gradio elements # Gradio elements

View File

@@ -66,13 +66,7 @@ def generate_html():
container_html = '<div class="character-container">' container_html = '<div class="character-container">'
image_html = "<div class='placeholder'></div>" image_html = "<div class='placeholder'></div>"
for i in [ for path in [Path(f"characters/{character}.{extension}") for extension in ['png', 'jpg', 'jpeg']]:
f"characters/{character}.png",
f"characters/{character}.jpg",
f"characters/{character}.jpeg",
]:
path = Path(i)
if path.exists(): if path.exists():
image_html = f'<img src="file/{get_image_cache(path)}">' image_html = f'<img src="file/{get_image_cache(path)}">'
break break

View File

@@ -7,6 +7,7 @@ params = {
language_codes = {'Afrikaans': 'af', 'Albanian': 'sq', 'Amharic': 'am', 'Arabic': 'ar', 'Armenian': 'hy', 'Azerbaijani': 'az', 'Basque': 'eu', 'Belarusian': 'be', 'Bengali': 'bn', 'Bosnian': 'bs', 'Bulgarian': 'bg', 'Catalan': 'ca', 'Cebuano': 'ceb', 'Chinese (Simplified)': 'zh-CN', 'Chinese (Traditional)': 'zh-TW', 'Corsican': 'co', 'Croatian': 'hr', 'Czech': 'cs', 'Danish': 'da', 'Dutch': 'nl', 'English': 'en', 'Esperanto': 'eo', 'Estonian': 'et', 'Finnish': 'fi', 'French': 'fr', 'Frisian': 'fy', 'Galician': 'gl', 'Georgian': 'ka', 'German': 'de', 'Greek': 'el', 'Gujarati': 'gu', 'Haitian Creole': 'ht', 'Hausa': 'ha', 'Hawaiian': 'haw', 'Hebrew': 'iw', 'Hindi': 'hi', 'Hmong': 'hmn', 'Hungarian': 'hu', 'Icelandic': 'is', 'Igbo': 'ig', 'Indonesian': 'id', 'Irish': 'ga', 'Italian': 'it', 'Japanese': 'ja', 'Javanese': 'jw', 'Kannada': 'kn', 'Kazakh': 'kk', 'Khmer': 'km', 'Korean': 'ko', 'Kurdish': 'ku', 'Kyrgyz': 'ky', 'Lao': 'lo', 'Latin': 'la', 'Latvian': 'lv', 'Lithuanian': 'lt', 'Luxembourgish': 'lb', 'Macedonian': 'mk', 'Malagasy': 'mg', 'Malay': 'ms', 'Malayalam': 'ml', 'Maltese': 'mt', 'Maori': 'mi', 'Marathi': 'mr', 'Mongolian': 'mn', 'Myanmar (Burmese)': 'my', 'Nepali': 'ne', 'Norwegian': 'no', 'Nyanja (Chichewa)': 'ny', 'Pashto': 'ps', 'Persian': 'fa', 'Polish': 'pl', 'Portuguese (Portugal, Brazil)': 'pt', 'Punjabi': 'pa', 'Romanian': 'ro', 'Russian': 'ru', 'Samoan': 'sm', 'Scots Gaelic': 'gd', 'Serbian': 'sr', 'Sesotho': 'st', 'Shona': 'sn', 'Sindhi': 'sd', 'Sinhala (Sinhalese)': 'si', 'Slovak': 'sk', 'Slovenian': 'sl', 'Somali': 'so', 'Spanish': 'es', 'Sundanese': 'su', 'Swahili': 'sw', 'Swedish': 'sv', 'Tagalog (Filipino)': 'tl', 'Tajik': 'tg', 'Tamil': 'ta', 'Telugu': 'te', 'Thai': 'th', 'Turkish': 'tr', 'Ukrainian': 'uk', 'Urdu': 'ur', 'Uzbek': 'uz', 'Vietnamese': 'vi', 'Welsh': 'cy', 'Xhosa': 'xh', 'Yiddish': 'yi', 'Yoruba': 'yo', 'Zulu': 'zu'} language_codes = {'Afrikaans': 'af', 'Albanian': 'sq', 'Amharic': 'am', 'Arabic': 'ar', 'Armenian': 'hy', 'Azerbaijani': 'az', 'Basque': 'eu', 'Belarusian': 'be', 'Bengali': 'bn', 'Bosnian': 'bs', 'Bulgarian': 'bg', 'Catalan': 'ca', 'Cebuano': 'ceb', 'Chinese (Simplified)': 'zh-CN', 'Chinese (Traditional)': 'zh-TW', 'Corsican': 'co', 'Croatian': 'hr', 'Czech': 'cs', 'Danish': 'da', 'Dutch': 'nl', 'English': 'en', 'Esperanto': 'eo', 'Estonian': 'et', 'Finnish': 'fi', 'French': 'fr', 'Frisian': 'fy', 'Galician': 'gl', 'Georgian': 'ka', 'German': 'de', 'Greek': 'el', 'Gujarati': 'gu', 'Haitian Creole': 'ht', 'Hausa': 'ha', 'Hawaiian': 'haw', 'Hebrew': 'iw', 'Hindi': 'hi', 'Hmong': 'hmn', 'Hungarian': 'hu', 'Icelandic': 'is', 'Igbo': 'ig', 'Indonesian': 'id', 'Irish': 'ga', 'Italian': 'it', 'Japanese': 'ja', 'Javanese': 'jw', 'Kannada': 'kn', 'Kazakh': 'kk', 'Khmer': 'km', 'Korean': 'ko', 'Kurdish': 'ku', 'Kyrgyz': 'ky', 'Lao': 'lo', 'Latin': 'la', 'Latvian': 'lv', 'Lithuanian': 'lt', 'Luxembourgish': 'lb', 'Macedonian': 'mk', 'Malagasy': 'mg', 'Malay': 'ms', 'Malayalam': 'ml', 'Maltese': 'mt', 'Maori': 'mi', 'Marathi': 'mr', 'Mongolian': 'mn', 'Myanmar (Burmese)': 'my', 'Nepali': 'ne', 'Norwegian': 'no', 'Nyanja (Chichewa)': 'ny', 'Pashto': 'ps', 'Persian': 'fa', 'Polish': 'pl', 'Portuguese (Portugal, Brazil)': 'pt', 'Punjabi': 'pa', 'Romanian': 'ro', 'Russian': 'ru', 'Samoan': 'sm', 'Scots Gaelic': 'gd', 'Serbian': 'sr', 'Sesotho': 'st', 'Shona': 'sn', 'Sindhi': 'sd', 'Sinhala (Sinhalese)': 'si', 'Slovak': 'sk', 'Slovenian': 'sl', 'Somali': 'so', 'Spanish': 'es', 'Sundanese': 'su', 'Swahili': 'sw', 'Swedish': 'sv', 'Tagalog (Filipino)': 'tl', 'Tajik': 'tg', 'Tamil': 'ta', 'Telugu': 'te', 'Thai': 'th', 'Turkish': 'tr', 'Ukrainian': 'uk', 'Urdu': 'ur', 'Uzbek': 'uz', 'Vietnamese': 'vi', 'Welsh': 'cy', 'Xhosa': 'xh', 'Yiddish': 'yi', 'Yoruba': 'yo', 'Zulu': 'zu'}
def input_modifier(string): def input_modifier(string):
""" """
This function is applied to your text inputs before This function is applied to your text inputs before
@@ -15,6 +16,7 @@ def input_modifier(string):
return GoogleTranslator(source=params['language string'], target='en').translate(string) return GoogleTranslator(source=params['language string'], target='en').translate(string)
def output_modifier(string): def output_modifier(string):
""" """
This function is applied to the model outputs. This function is applied to the model outputs.
@@ -22,6 +24,7 @@ def output_modifier(string):
return GoogleTranslator(source='en', target=params['language string']).translate(string) return GoogleTranslator(source='en', target=params['language string']).translate(string)
def bot_prefix_modifier(string): def bot_prefix_modifier(string):
""" """
This function is only applied in chat mode. It modifies This function is only applied in chat mode. It modifies
@@ -31,6 +34,7 @@ def bot_prefix_modifier(string):
return string return string
def ui(): def ui():
# Finding the language name from the language code to use as the default value # Finding the language name from the language code to use as the default value
language_name = list(language_codes.keys())[list(language_codes.values()).index(params['language string'])] language_name = list(language_codes.keys())[list(language_codes.values()).index(params['language string'])]

View File

@@ -4,12 +4,14 @@ import pandas as pd
df = pd.read_csv("https://raw.githubusercontent.com/devbrones/llama-prompts/main/prompts/prompts.csv") df = pd.read_csv("https://raw.githubusercontent.com/devbrones/llama-prompts/main/prompts/prompts.csv")
def get_prompt_by_name(name): def get_prompt_by_name(name):
if name == 'None': if name == 'None':
return '' return ''
else: else:
return df[df['Prompt name'] == name].iloc[0]['Prompt'].replace('\\n', '\n') return df[df['Prompt name'] == name].iloc[0]['Prompt'].replace('\\n', '\n')
def ui(): def ui():
if not shared.is_chat(): if not shared.is_chat():
choices = ['None'] + list(df['Prompt name']) choices = ['None'] + list(df['Prompt name'])

View File

@@ -30,12 +30,15 @@ streaming_state = shared.args.no_stream # remember if chat streaming was enabled
picture_response = False # specifies if the next model response should appear as a picture picture_response = False # specifies if the next model response should appear as a picture
pic_id = 0 pic_id = 0
def remove_surrounded_chars(string): def remove_surrounded_chars(string):
# this expression matches to 'as few symbols as possible (0 upwards) between any asterisks' OR # this expression matches to 'as few symbols as possible (0 upwards) between any asterisks' OR
# 'as few symbols as possible (0 upwards) between an asterisk and the end of the string' # 'as few symbols as possible (0 upwards) between an asterisk and the end of the string'
return re.sub('\*[^\*]*?(\*|$)', '', string) return re.sub('\*[^\*]*?(\*|$)', '', string)
# I don't even need input_hijack for this as visible text will be commited to history as the unmodified string # I don't even need input_hijack for this as visible text will be commited to history as the unmodified string
def input_modifier(string): def input_modifier(string):
""" """
This function is applied to your text inputs before This function is applied to your text inputs before
@@ -62,6 +65,8 @@ def input_modifier(string):
return string return string
# Get and save the Stable Diffusion-generated picture # Get and save the Stable Diffusion-generated picture
def get_SD_pictures(description): def get_SD_pictures(description):
global params, pic_id global params, pic_id
@@ -101,6 +106,8 @@ def get_SD_pictures(description):
# TODO: how do I make the UI history ignore the resulting pictures (I don't want HTML to appear in history) # TODO: how do I make the UI history ignore the resulting pictures (I don't want HTML to appear in history)
# and replace it with 'text' for the purposes of logging? # and replace it with 'text' for the purposes of logging?
def output_modifier(string): def output_modifier(string):
""" """
This function is applied to the model outputs. This function is applied to the model outputs.
@@ -130,6 +137,7 @@ def output_modifier(string):
shared.args.no_stream = streaming_state shared.args.no_stream = streaming_state
return image + "\n" + text return image + "\n" + text
def bot_prefix_modifier(string): def bot_prefix_modifier(string):
""" """
This function is only applied in chat mode. It modifies This function is only applied in chat mode. It modifies
@@ -139,10 +147,12 @@ def bot_prefix_modifier(string):
return string return string
def force_pic(): def force_pic():
global picture_response global picture_response
picture_response = True picture_response = True
def ui(): def ui():
# Gradio elements # Gradio elements
@@ -176,4 +186,4 @@ def ui():
force_btn.click(force_pic) force_btn.click(force_pic)
generate_now_btn.click(force_pic) generate_now_btn.click(force_pic)
generate_now_btn.click(eval('chat.cai_chatbot_wrapper'), shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream) generate_now_btn.click(chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream)

View File

@@ -2,12 +2,11 @@ import base64
from io import BytesIO from io import BytesIO
import gradio as gr import gradio as gr
import modules.chat as chat
import modules.shared as shared
import torch import torch
from PIL import Image
from transformers import BlipForConditionalGeneration, BlipProcessor from transformers import BlipForConditionalGeneration, BlipProcessor
from modules import chat, shared
# If 'state' is True, will hijack the next chat generation with # If 'state' is True, will hijack the next chat generation with
# custom input text given by 'value' in the format [text, visible_text] # custom input text given by 'value' in the format [text, visible_text]
input_hijack = { input_hijack = {
@@ -18,11 +17,13 @@ input_hijack = {
processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base") processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base", torch_dtype=torch.float32).to("cpu") model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base", torch_dtype=torch.float32).to("cpu")
def caption_image(raw_image): def caption_image(raw_image):
inputs = processor(raw_image.convert('RGB'), return_tensors="pt").to("cpu", torch.float32) inputs = processor(raw_image.convert('RGB'), return_tensors="pt").to("cpu", torch.float32)
out = model.generate(**inputs, max_new_tokens=100) out = model.generate(**inputs, max_new_tokens=100)
return processor.decode(out[0], skip_special_tokens=True) return processor.decode(out[0], skip_special_tokens=True)
def generate_chat_picture(picture, name1, name2): def generate_chat_picture(picture, name1, name2):
text = f'*{name1} sends {name2} a picture that contains the following: "{caption_image(picture)}"*' text = f'*{name1} sends {name2} a picture that contains the following: "{caption_image(picture)}"*'
# lower the resolution of sent images for the chat, otherwise the log size gets out of control quickly with all the base64 values in visible history # lower the resolution of sent images for the chat, otherwise the log size gets out of control quickly with all the base64 values in visible history
@@ -33,16 +34,15 @@ def generate_chat_picture(picture, name1, name2):
visible_text = f'<img src="data:image/jpeg;base64,{img_str}" alt="{text}">' visible_text = f'<img src="data:image/jpeg;base64,{img_str}" alt="{text}">'
return text, visible_text return text, visible_text
def ui(): def ui():
picture_select = gr.Image(label='Send a picture', type='pil') picture_select = gr.Image(label='Send a picture', type='pil')
function_call = 'chat.cai_chatbot_wrapper' if shared.args.cai_chat else 'chat.chatbot_wrapper'
# Prepare the hijack with custom inputs # Prepare the hijack with custom inputs
picture_select.upload(lambda picture, name1, name2: input_hijack.update({"state": True, "value": generate_chat_picture(picture, name1, name2)}), [picture_select, shared.gradio['name1'], shared.gradio['name2']], None) picture_select.upload(lambda picture, name1, name2: input_hijack.update({"state": True, "value": generate_chat_picture(picture, name1, name2)}), [picture_select, shared.gradio['name1'], shared.gradio['name2']], None)
# Call the generation function # Call the generation function
picture_select.upload(eval(function_call), shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream) picture_select.upload(chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream)
# Clear the picture from the upload field # Clear the picture from the upload field
picture_select.upload(lambda: None, [], [picture_select], show_progress=False) picture_select.upload(lambda: None, [], [picture_select], show_progress=False)

View File

@@ -1,3 +1,4 @@
import inspect
import re import re
import sys import sys
from pathlib import Path from pathlib import Path
@@ -15,10 +16,12 @@ from modelutils import find_layers
from quant import make_quant from quant import make_quant
def _load_quant(model, checkpoint, wbits, groupsize=-1, exclude_layers=['lm_head']): def _load_quant(model, checkpoint, wbits, groupsize=-1, faster_kernel=False, exclude_layers=['lm_head'], kernel_switch_threshold=128):
config = AutoConfig.from_pretrained(model)
def noop(*args, **kwargs): def noop(*args, **kwargs):
pass pass
config = AutoConfig.from_pretrained(model)
torch.nn.init.kaiming_uniform_ = noop torch.nn.init.kaiming_uniform_ = noop
torch.nn.init.uniform_ = noop torch.nn.init.uniform_ = noop
torch.nn.init.normal_ = noop torch.nn.init.normal_ = noop
@@ -33,7 +36,22 @@ def _load_quant(model, checkpoint, wbits, groupsize=-1, exclude_layers=['lm_head
for name in exclude_layers: for name in exclude_layers:
if name in layers: if name in layers:
del layers[name] del layers[name]
make_quant(model, layers, wbits, groupsize)
gptq_args = inspect.getfullargspec(make_quant).args
make_quant_kwargs = {
'module': model,
'names': layers,
'bits': wbits,
}
if 'groupsize' in gptq_args:
make_quant_kwargs['groupsize'] = groupsize
if 'faster' in gptq_args:
make_quant_kwargs['faster'] = faster_kernel
if 'kernel_switch_threshold' in gptq_args:
make_quant_kwargs['kernel_switch_threshold'] = kernel_switch_threshold
make_quant(**make_quant_kwargs)
del layers del layers
@@ -48,6 +66,7 @@ def _load_quant(model, checkpoint, wbits, groupsize=-1, exclude_layers=['lm_head
return model return model
def load_quantized(model_name): def load_quantized(model_name):
if not shared.args.model_type: if not shared.args.model_type:
# Try to determine model type from model name # Try to determine model type from model name
@@ -65,9 +84,11 @@ def load_quantized(model_name):
else: else:
model_type = shared.args.model_type.lower() model_type = shared.args.model_type.lower()
if model_type == 'llama' and shared.args.pre_layer: if shared.args.pre_layer and model_type == 'llama':
load_quant = llama_inference_offload.load_quant load_quant = llama_inference_offload.load_quant
elif model_type in ('llama', 'opt', 'gptj'): elif model_type in ('llama', 'opt', 'gptj'):
if shared.args.pre_layer:
print("Warning: ignoring --pre_layer because it only works for llama model type.")
load_quant = _load_quant load_quant = _load_quant
else: else:
print("Unknown pre-quantized model type specified. Only 'llama', 'opt' and 'gptj' are supported") print("Unknown pre-quantized model type specified. Only 'llama', 'opt' and 'gptj' are supported")
@@ -107,10 +128,11 @@ def load_quantized(model_name):
exit() exit()
# qwopqwop200's offload # qwopqwop200's offload
if shared.args.pre_layer: if model_type == 'llama' and shared.args.pre_layer:
model = load_quant(str(path_to_model), str(pt_path), shared.args.wbits, shared.args.groupsize, shared.args.pre_layer) model = load_quant(str(path_to_model), str(pt_path), shared.args.wbits, shared.args.groupsize, shared.args.pre_layer)
else: else:
model = load_quant(str(path_to_model), str(pt_path), shared.args.wbits, shared.args.groupsize) threshold = False if model_type == 'gptj' else 128
model = load_quant(str(path_to_model), str(pt_path), shared.args.wbits, shared.args.groupsize, kernel_switch_threshold=threshold)
# accelerate offload (doesn't work properly) # accelerate offload (doesn't work properly)
if shared.args.gpu_memory: if shared.args.gpu_memory:

View File

@@ -13,6 +13,7 @@ def reload_model():
clear_torch_cache() clear_torch_cache()
shared.model, shared.tokenizer = load_model(shared.model_name) shared.model, shared.tokenizer = load_model(shared.model_name)
def add_lora_to_model(lora_name): def add_lora_to_model(lora_name):
# If a LoRA had been previously loaded, or if we want # If a LoRA had been previously loaded, or if we want

View File

@@ -54,6 +54,7 @@ class RWKVModel:
reply += token reply += token
yield reply yield reply
class RWKVTokenizer: class RWKVTokenizer:
def __init__(self): def __init__(self):
pass pass

39
modules/api.py Normal file
View File

@@ -0,0 +1,39 @@
import json
import gradio as gr
from modules import shared
from modules.text_generation import generate_reply
def generate_reply_wrapper(string):
generate_params = {
'do_sample': True,
'temperature': 1,
'top_p': 1,
'typical_p': 1,
'repetition_penalty': 1,
'encoder_repetition_penalty': 1,
'top_k': 50,
'num_beams': 1,
'penalty_alpha': 0,
'min_length': 0,
'length_penalty': 1,
'no_repeat_ngram_size': 0,
'early_stopping': False,
}
params = json.loads(string)
for k in params[1]:
generate_params[k] = params[1][k]
for i in generate_reply(params[0], generate_params):
yield i
def create_apis():
t1 = gr.Textbox(visible=False)
t2 = gr.Textbox(visible=False)
dummy = gr.Button(visible=False)
input_params = [t1]
output_params = [t2] + [shared.gradio[k] for k in ['markdown', 'html']]
dummy.click(generate_reply_wrapper, input_params, output_params, api_name='textgen')

View File

@@ -30,6 +30,7 @@ class _SentinelTokenStoppingCriteria(transformers.StoppingCriteria):
return True return True
return False return False
class Stream(transformers.StoppingCriteria): class Stream(transformers.StoppingCriteria):
def __init__(self, callback_func=None): def __init__(self, callback_func=None):
self.callback_func = callback_func self.callback_func = callback_func
@@ -39,6 +40,7 @@ class Stream(transformers.StoppingCriteria):
self.callback_func(input_ids[0]) self.callback_func(input_ids[0])
return False return False
class Iteratorize: class Iteratorize:
""" """
@@ -96,6 +98,7 @@ class Iteratorize:
self.stop_now = True self.stop_now = True
clear_torch_cache() clear_torch_cache()
def clear_torch_cache(): def clear_torch_cache():
gc.collect() gc.collect()
if not shared.args.cpu: if not shared.args.cpu:

View File

@@ -12,46 +12,54 @@ from PIL import Image
import modules.extensions as extensions_module import modules.extensions as extensions_module
import modules.shared as shared import modules.shared as shared
from modules.extensions import apply_extensions from modules.extensions import apply_extensions
from modules.html_generator import (fix_newlines, generate_chat_html, from modules.html_generator import (fix_newlines, chat_html_wrapper,
make_thumbnail) make_thumbnail)
from modules.text_generation import (encode, generate_reply, from modules.text_generation import (encode, generate_reply,
get_max_prompt_length) get_max_prompt_length)
def generate_chat_output(history, name1, name2): def generate_chat_prompt(user_input, max_new_tokens, name1, name2, context, chat_prompt_size, **kwargs):
if shared.args.cai_chat: is_instruct = kwargs['is_instruct'] if 'is_instruct' in kwargs else False
return generate_chat_html(history, name1, name2) end_of_turn = kwargs['end_of_turn'] if 'end_of_turn' in kwargs else ''
else: impersonate = kwargs['impersonate'] if 'impersonate' in kwargs else False
return history also_return_rows = kwargs['also_return_rows'] if 'also_return_rows' in kwargs else False
def generate_chat_prompt(user_input, max_new_tokens, name1, name2, context, chat_prompt_size, impersonate=False, also_return_rows=False):
user_input = fix_newlines(user_input)
rows = [f"{context.strip()}\n"] rows = [f"{context.strip()}\n"]
# Finding the maximum prompt size
if shared.soft_prompt: if shared.soft_prompt:
chat_prompt_size -= shared.soft_prompt_tensor.shape[1] chat_prompt_size -= shared.soft_prompt_tensor.shape[1]
max_length = min(get_max_prompt_length(max_new_tokens), chat_prompt_size) max_length = min(get_max_prompt_length(max_new_tokens), chat_prompt_size)
if is_instruct:
prefix1 = f"{name1}\n"
prefix2 = f"{name2}\n"
else:
prefix1 = f"{name1}: "
prefix2 = f"{name2}: "
i = len(shared.history['internal']) - 1 i = len(shared.history['internal']) - 1
while i >= 0 and len(encode(''.join(rows), max_new_tokens)[0]) < max_length: while i >= 0 and len(encode(''.join(rows), max_new_tokens)[0]) < max_length:
rows.insert(1, f"{name2}: {shared.history['internal'][i][1].strip()}\n") rows.insert(1, f"{prefix2}{shared.history['internal'][i][1].strip()}{end_of_turn}\n")
prev_user_input = shared.history['internal'][i][0] string = shared.history['internal'][i][0]
if prev_user_input not in ['', '<|BEGIN-VISIBLE-CHAT|>']: if string not in ['', '<|BEGIN-VISIBLE-CHAT|>']:
rows.insert(1, f"{name1}: {prev_user_input.strip()}\n") rows.insert(1, f"{prefix1}{string.strip()}{end_of_turn}\n")
i -= 1 i -= 1
if not impersonate: if impersonate:
if len(user_input) > 0: rows.append(f"{prefix1.strip() if not is_instruct else prefix1}")
rows.append(f"{name1}: {user_input}\n")
rows.append(apply_extensions(f"{name2}:", "bot_prefix"))
limit = 3
else:
rows.append(f"{name1}:")
limit = 2 limit = 2
else:
# Adding the user message
user_input = fix_newlines(user_input)
if len(user_input) > 0:
rows.append(f"{prefix1}{user_input}{end_of_turn}\n")
# Adding the Character prefix
rows.append(apply_extensions(f"{prefix2.strip() if not is_instruct else prefix2}", "bot_prefix"))
limit = 3
while len(rows) > limit and len(encode(''.join(rows), max_new_tokens)[0]) >= max_length: while len(rows) > limit and len(encode(''.join(rows), max_new_tokens)[0]) >= max_length:
rows.pop(1) rows.pop(1)
prompt = ''.join(rows) prompt = ''.join(rows)
if also_return_rows: if also_return_rows:
@@ -59,6 +67,7 @@ def generate_chat_prompt(user_input, max_new_tokens, name1, name2, context, chat
else: else:
return prompt return prompt
def extract_message_from_reply(reply, name1, name2, stop_at_newline): def extract_message_from_reply(reply, name1, name2, stop_at_newline):
next_character_found = False next_character_found = False
@@ -82,13 +91,21 @@ def extract_message_from_reply(reply, name1, name2, stop_at_newline):
if reply[-j:] == string[:j]: if reply[-j:] == string[:j]:
reply = reply[:-j] reply = reply[:-j]
break break
else:
continue
break
reply = fix_newlines(reply) reply = fix_newlines(reply)
return reply, next_character_found return reply, next_character_found
def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, name1, name2, context, stop_at_newline, chat_prompt_size, chat_generation_attempts=1, regenerate=False):
just_started = True def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn, regenerate=False):
eos_token = '\n' if stop_at_newline else None if mode == 'instruct':
stopping_strings = [f"\n{name1}", f"\n{name2}"]
else:
stopping_strings = [f"\n{name1}:", f"\n{name2}:"]
eos_token = '\n' if generate_state['stop_at_newline'] else None
name1_original = name1 name1_original = name1
if 'pygmalion' in shared.model_name.lower(): if 'pygmalion' in shared.model_name.lower():
name1 = "You" name1 = "You"
@@ -97,7 +114,7 @@ def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical
visible_text = None visible_text = None
custom_generate_chat_prompt = None custom_generate_chat_prompt = None
for extension, _ in extensions_module.iterator(): for extension, _ in extensions_module.iterator():
if hasattr(extension, 'input_hijack') and extension.input_hijack['state'] == True: if hasattr(extension, 'input_hijack') and extension.input_hijack['state']:
extension.input_hijack['state'] = False extension.input_hijack['state'] = False
text, visible_text = extension.input_hijack['value'] text, visible_text = extension.input_hijack['value']
if custom_generate_chat_prompt is None and hasattr(extension, 'custom_generate_chat_prompt'): if custom_generate_chat_prompt is None and hasattr(extension, 'custom_generate_chat_prompt'):
@@ -105,14 +122,13 @@ def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical
if visible_text is None: if visible_text is None:
visible_text = text visible_text = text
if shared.args.chat:
visible_text = visible_text.replace('\n', '<br>')
text = apply_extensions(text, "input") text = apply_extensions(text, "input")
kwargs = {'end_of_turn': end_of_turn, 'is_instruct': mode == 'instruct'}
if custom_generate_chat_prompt is None: if custom_generate_chat_prompt is None:
prompt = generate_chat_prompt(text, max_new_tokens, name1, name2, context, chat_prompt_size) prompt = generate_chat_prompt(text, generate_state['max_new_tokens'], name1, name2, context, generate_state['chat_prompt_size'], **kwargs)
else: else:
prompt = custom_generate_chat_prompt(text, max_new_tokens, name1, name2, context, chat_prompt_size) prompt = custom_generate_chat_prompt(text, generate_state['max_new_tokens'], name1, name2, context, generate_state['chat_prompt_size'], **kwargs)
# Yield *Is typing...* # Yield *Is typing...*
if not regenerate: if not regenerate:
@@ -120,17 +136,16 @@ def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical
# Generate # Generate
cumulative_reply = '' cumulative_reply = ''
for i in range(chat_generation_attempts): just_started = True
for i in range(generate_state['chat_generation_attempts']):
reply = None reply = None
for reply in generate_reply(f"{prompt}{' ' if len(cumulative_reply) > 0 else ''}{cumulative_reply}", max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, eos_token=eos_token, stopping_strings=[f"\n{name1}:", f"\n{name2}:"]): for reply in generate_reply(f"{prompt}{' ' if len(cumulative_reply) > 0 else ''}{cumulative_reply}", generate_state, eos_token=eos_token, stopping_strings=stopping_strings):
reply = cumulative_reply + reply reply = cumulative_reply + reply
# Extracting the reply # Extracting the reply
reply, next_character_found = extract_message_from_reply(reply, name1, name2, stop_at_newline) reply, next_character_found = extract_message_from_reply(reply, name1, name2, generate_state['stop_at_newline'])
visible_reply = re.sub("(<USER>|<user>|{{user}})", name1_original, reply) visible_reply = re.sub("(<USER>|<user>|{{user}})", name1_original, reply)
visible_reply = apply_extensions(visible_reply, "output") visible_reply = apply_extensions(visible_reply, "output")
if shared.args.chat:
visible_reply = visible_reply.replace('\n', '<br>')
# We need this global variable to handle the Stop event, # We need this global variable to handle the Stop event,
# otherwise gradio gets confused # otherwise gradio gets confused
@@ -153,23 +168,28 @@ def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical
yield shared.history['visible'] yield shared.history['visible']
def impersonate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, name1, name2, context, stop_at_newline, chat_prompt_size, chat_generation_attempts=1):
eos_token = '\n' if stop_at_newline else None
def impersonate_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn):
if mode == 'instruct':
stopping_strings = [f"\n{name1}", f"\n{name2}"]
else:
stopping_strings = [f"\n{name1}:", f"\n{name2}:"]
eos_token = '\n' if generate_state['stop_at_newline'] else None
if 'pygmalion' in shared.model_name.lower(): if 'pygmalion' in shared.model_name.lower():
name1 = "You" name1 = "You"
prompt = generate_chat_prompt(text, max_new_tokens, name1, name2, context, chat_prompt_size, impersonate=True) prompt = generate_chat_prompt(text, generate_state['max_new_tokens'], name1, name2, context, generate_state['chat_prompt_size'], impersonate=True, end_of_turn=end_of_turn)
# Yield *Is typing...* # Yield *Is typing...*
yield shared.processing_message yield shared.processing_message
cumulative_reply = '' cumulative_reply = ''
for i in range(chat_generation_attempts): for i in range(generate_state['chat_generation_attempts']):
reply = None reply = None
for reply in generate_reply(f"{prompt}{' ' if len(cumulative_reply) > 0 else ''}{cumulative_reply}", max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, eos_token=eos_token, stopping_strings=[f"\n{name1}:", f"\n{name2}:"]): for reply in generate_reply(f"{prompt}{' ' if len(cumulative_reply) > 0 else ''}{cumulative_reply}", generate_state, eos_token=eos_token, stopping_strings=stopping_strings):
reply = cumulative_reply + reply reply = cumulative_reply + reply
reply, next_character_found = extract_message_from_reply(reply, name1, name2, stop_at_newline) reply, next_character_found = extract_message_from_reply(reply, name1, name2, generate_state['stop_at_newline'])
yield reply yield reply
if next_character_found: if next_character_found:
break break
@@ -179,36 +199,34 @@ def impersonate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typ
yield reply yield reply
def cai_chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, name1, name2, context, stop_at_newline, chat_prompt_size, chat_generation_attempts=1):
for history in chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, name1, name2, context, stop_at_newline, chat_prompt_size, chat_generation_attempts):
yield generate_chat_html(history, name1, name2)
def regenerate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, name1, name2, context, stop_at_newline, chat_prompt_size, chat_generation_attempts=1): def cai_chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn):
for history in chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn):
yield chat_html_wrapper(history, name1, name2, mode)
def regenerate_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn):
if (shared.character != 'None' and len(shared.history['visible']) == 1) or len(shared.history['internal']) == 0: if (shared.character != 'None' and len(shared.history['visible']) == 1) or len(shared.history['internal']) == 0:
yield generate_chat_output(shared.history['visible'], name1, name2) yield chat_html_wrapper(shared.history['visible'], name1, name2, mode)
else: else:
last_visible = shared.history['visible'].pop() last_visible = shared.history['visible'].pop()
last_internal = shared.history['internal'].pop() last_internal = shared.history['internal'].pop()
# Yield '*Is typing...*' # Yield '*Is typing...*'
yield generate_chat_output(shared.history['visible']+[[last_visible[0], shared.processing_message]], name1, name2) yield chat_html_wrapper(shared.history['visible'] + [[last_visible[0], shared.processing_message]], name1, name2, mode)
for history in chatbot_wrapper(last_internal[0], max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, name1, name2, context, stop_at_newline, chat_prompt_size, chat_generation_attempts, regenerate=True): for history in chatbot_wrapper(last_internal[0], generate_state, name1, name2, context, mode, end_of_turn, regenerate=True):
if shared.args.cai_chat:
shared.history['visible'][-1] = [last_visible[0], history[-1][1]] shared.history['visible'][-1] = [last_visible[0], history[-1][1]]
else: yield chat_html_wrapper(shared.history['visible'], name1, name2, mode)
shared.history['visible'][-1] = (last_visible[0], history[-1][1])
yield generate_chat_output(shared.history['visible'], name1, name2)
def remove_last_message(name1, name2):
def remove_last_message(name1, name2, mode):
if len(shared.history['visible']) > 0 and shared.history['internal'][-1][0] != '<|BEGIN-VISIBLE-CHAT|>': if len(shared.history['visible']) > 0 and shared.history['internal'][-1][0] != '<|BEGIN-VISIBLE-CHAT|>':
last = shared.history['visible'].pop() last = shared.history['visible'].pop()
shared.history['internal'].pop() shared.history['internal'].pop()
else: else:
last = ['', ''] last = ['', '']
if shared.args.cai_chat: return chat_html_wrapper(shared.history['visible'], name1, name2, mode), last[0]
return generate_chat_html(shared.history['visible'], name1, name2), last[0]
else:
return shared.history['visible'], last[0]
def send_last_reply_to_input(): def send_last_reply_to_input():
if len(shared.history['internal']) > 0: if len(shared.history['internal']) > 0:
@@ -216,20 +234,20 @@ def send_last_reply_to_input():
else: else:
return '' return ''
def replace_last_reply(text, name1, name2):
def replace_last_reply(text, name1, name2, mode):
if len(shared.history['visible']) > 0: if len(shared.history['visible']) > 0:
if shared.args.cai_chat:
shared.history['visible'][-1][1] = text shared.history['visible'][-1][1] = text
else:
shared.history['visible'][-1] = (shared.history['visible'][-1][0], text)
shared.history['internal'][-1][1] = apply_extensions(text, "input") shared.history['internal'][-1][1] = apply_extensions(text, "input")
return generate_chat_output(shared.history['visible'], name1, name2) return chat_html_wrapper(shared.history['visible'], name1, name2, mode)
def clear_html(): def clear_html():
return generate_chat_html([], "", "") return chat_html_wrapper([], "", "")
def clear_chat_log(name1, name2, greeting):
def clear_chat_log(name1, name2, greeting, mode):
shared.history['visible'] = [] shared.history['visible'] = []
shared.history['internal'] = [] shared.history['internal'] = []
@@ -237,12 +255,14 @@ def clear_chat_log(name1, name2, greeting):
shared.history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', greeting]] shared.history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', greeting]]
shared.history['visible'] += [['', apply_extensions(greeting, "output")]] shared.history['visible'] += [['', apply_extensions(greeting, "output")]]
return generate_chat_output(shared.history['visible'], name1, name2) return chat_html_wrapper(shared.history['visible'], name1, name2, mode)
def redraw_html(name1, name2):
return generate_chat_html(shared.history['visible'], name1, name2)
def tokenize_dialogue(dialogue, name1, name2): def redraw_html(name1, name2, mode):
return chat_html_wrapper(shared.history['visible'], name1, name2, mode)
def tokenize_dialogue(dialogue, name1, name2, mode):
history = [] history = []
dialogue = re.sub('<START>', '', dialogue) dialogue = re.sub('<START>', '', dialogue)
@@ -279,6 +299,7 @@ def tokenize_dialogue(dialogue, name1, name2):
return history return history
def save_history(timestamp=True): def save_history(timestamp=True):
if timestamp: if timestamp:
fname = f"{shared.character}_{datetime.now().strftime('%Y%m%d-%H%M%S')}.json" fname = f"{shared.character}_{datetime.now().strftime('%Y%m%d-%H%M%S')}.json"
@@ -290,6 +311,7 @@ def save_history(timestamp=True):
f.write(json.dumps({'data': shared.history['internal'], 'data_visible': shared.history['visible']}, indent=2)) f.write(json.dumps({'data': shared.history['internal'], 'data_visible': shared.history['visible']}, indent=2))
return Path(f'logs/{fname}') return Path(f'logs/{fname}')
def load_history(file, name1, name2): def load_history(file, name1, name2):
file = file.decode('utf-8') file = file.decode('utf-8')
try: try:
@@ -314,10 +336,12 @@ def load_history(file, name1, name2):
shared.history['internal'] = tokenize_dialogue(file, name1, name2) shared.history['internal'] = tokenize_dialogue(file, name1, name2)
shared.history['visible'] = copy.deepcopy(shared.history['internal']) shared.history['visible'] = copy.deepcopy(shared.history['internal'])
def replace_character_names(text, name1, name2): def replace_character_names(text, name1, name2):
text = text.replace('{{user}}', name1).replace('{{char}}', name2) text = text.replace('{{user}}', name1).replace('{{char}}', name2)
return text.replace('<USER>', name1).replace('<BOT>', name2) return text.replace('<USER>', name1).replace('<BOT>', name2)
def build_pygmalion_style_context(data): def build_pygmalion_style_context(data):
context = "" context = ""
if 'char_persona' in data and data['char_persona'] != '': if 'char_persona' in data and data['char_persona'] != '':
@@ -327,6 +351,7 @@ def build_pygmalion_style_context(data):
context = f"{context.strip()}\n<START>\n" context = f"{context.strip()}\n<START>\n"
return context return context
def generate_pfp_cache(character): def generate_pfp_cache(character):
cache_folder = Path("cache") cache_folder = Path("cache")
if not cache_folder.exists(): if not cache_folder.exists():
@@ -339,11 +364,13 @@ def generate_pfp_cache(character):
return img return img
return None return None
def load_character(character, name1, name2):
def load_character(character, name1, name2, mode):
shared.character = character shared.character = character
shared.history['internal'] = [] shared.history['internal'] = []
shared.history['visible'] = [] shared.history['visible'] = []
greeting = "" context = greeting = end_of_turn = ""
greeting_field = 'greeting'
picture = None picture = None
# Deleting the profile picture cache, if any # Deleting the profile picture cache, if any
@@ -351,9 +378,10 @@ def load_character(character, name1, name2):
Path("cache/pfp_character.png").unlink() Path("cache/pfp_character.png").unlink()
if character != 'None': if character != 'None':
folder = 'characters' if not mode == 'instruct' else 'characters/instruction-following'
picture = generate_pfp_cache(character) picture = generate_pfp_cache(character)
for extension in ["yml", "yaml", "json"]: for extension in ["yml", "yaml", "json"]:
filepath = Path(f'characters/{character}.{extension}') filepath = Path(f'{folder}/{character}.{extension}')
if filepath.exists(): if filepath.exists():
break break
file_contents = open(filepath, 'r', encoding='utf-8').read() file_contents = open(filepath, 'r', encoding='utf-8').read()
@@ -369,19 +397,21 @@ def load_character(character, name1, name2):
if 'context' in data: if 'context' in data:
context = f"{data['context'].strip()}\n\n" context = f"{data['context'].strip()}\n\n"
greeting_field = 'greeting' elif "char_persona" in data:
else:
context = build_pygmalion_style_context(data) context = build_pygmalion_style_context(data)
greeting_field = 'char_greeting' greeting_field = 'char_greeting'
if 'example_dialogue' in data and data['example_dialogue'] != '': if 'example_dialogue' in data:
context += f"{data['example_dialogue'].strip()}\n" context += f"{data['example_dialogue'].strip()}\n"
if greeting_field in data and len(data[greeting_field].strip()) > 0: if greeting_field in data:
greeting = data[greeting_field] greeting = data[greeting_field]
if 'end_of_turn' in data:
end_of_turn = data['end_of_turn']
else: else:
context = shared.settings['context'] context = shared.settings['context']
name2 = shared.settings['name2'] name2 = shared.settings['name2']
greeting = shared.settings['greeting'] greeting = shared.settings['greeting']
end_of_turn = shared.settings['end_of_turn']
if Path(f'logs/{shared.character}_persistent.json').exists(): if Path(f'logs/{shared.character}_persistent.json').exists():
load_history(open(Path(f'logs/{shared.character}_persistent.json'), 'rb').read(), name1, name2) load_history(open(Path(f'logs/{shared.character}_persistent.json'), 'rb').read(), name1, name2)
@@ -389,13 +419,12 @@ def load_character(character, name1, name2):
shared.history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', greeting]] shared.history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', greeting]]
shared.history['visible'] += [['', apply_extensions(greeting, "output")]] shared.history['visible'] += [['', apply_extensions(greeting, "output")]]
if shared.args.cai_chat: return name1, name2, picture, greeting, context, end_of_turn, chat_html_wrapper(shared.history['visible'], name1, name2, mode, reset_cache=True)
return name1, name2, picture, greeting, context, generate_chat_html(shared.history['visible'], name1, name2, reset_cache=True)
else:
return name1, name2, picture, greeting, context, shared.history['visible']
def load_default_history(name1, name2): def load_default_history(name1, name2):
load_character("None", name1, name2) load_character("None", name1, name2, "chat")
def upload_character(json_file, img, tavern=False): def upload_character(json_file, img, tavern=False):
json_file = json_file if type(json_file) == str else json_file.decode('utf-8') json_file = json_file if type(json_file) == str else json_file.decode('utf-8')
@@ -415,6 +444,7 @@ def upload_character(json_file, img, tavern=False):
print(f'New character saved to "characters/{outfile_name}.json".') print(f'New character saved to "characters/{outfile_name}.json".')
return outfile_name return outfile_name
def upload_tavern_character(img, name1, name2): def upload_tavern_character(img, name1, name2):
_img = Image.open(io.BytesIO(img)) _img = Image.open(io.BytesIO(img))
_img.getexif() _img.getexif()
@@ -423,12 +453,13 @@ def upload_tavern_character(img, name1, name2):
_json = {"char_name": _json['name'], "char_persona": _json['description'], "char_greeting": _json["first_mes"], "example_dialogue": _json['mes_example'], "world_scenario": _json['scenario']} _json = {"char_name": _json['name'], "char_persona": _json['description'], "char_greeting": _json["first_mes"], "example_dialogue": _json['mes_example'], "world_scenario": _json['scenario']}
return upload_character(json.dumps(_json), img, tavern=True) return upload_character(json.dumps(_json), img, tavern=True)
def upload_your_profile_picture(img, name1, name2):
def upload_your_profile_picture(img, name1, name2, mode):
cache_folder = Path("cache") cache_folder = Path("cache")
if not cache_folder.exists(): if not cache_folder.exists():
cache_folder.mkdir() cache_folder.mkdir()
if img == None: if img is None:
if Path("cache/pfp_me.png").exists(): if Path("cache/pfp_me.png").exists():
Path("cache/pfp_me.png").unlink() Path("cache/pfp_me.png").unlink()
else: else:
@@ -436,7 +467,4 @@ def upload_your_profile_picture(img, name1, name2):
img.save(Path('cache/pfp_me.png')) img.save(Path('cache/pfp_me.png'))
print('Profile picture saved to "cache/pfp_me.png"') print('Profile picture saved to "cache/pfp_me.png"')
if shared.args.cai_chat: return chat_html_wrapper(shared.history['visible'], name1, name2, mode, reset_cache=True)
return generate_chat_html(shared.history['visible'], name1, name2, reset_cache=True)
else:
return shared.history['visible']

View File

@@ -9,6 +9,7 @@ state = {}
available_extensions = [] available_extensions = []
setup_called = set() setup_called = set()
def load_extensions(): def load_extensions():
global state global state
for i, name in enumerate(shared.args.extensions): for i, name in enumerate(shared.args.extensions):
@@ -23,12 +24,16 @@ def load_extensions():
traceback.print_exc() traceback.print_exc()
# This iterator returns the extensions in the order specified in the command-line # This iterator returns the extensions in the order specified in the command-line
def iterator(): def iterator():
for name in sorted(state, key=lambda x: state[x][1]): for name in sorted(state, key=lambda x: state[x][1]):
if state[name][0] == True: if state[name][0] == True:
yield eval(f"extensions.{name}.script"), name yield eval(f"extensions.{name}.script"), name
# Extension functions that map string -> string # Extension functions that map string -> string
def apply_extensions(text, typ): def apply_extensions(text, typ):
for extension, _ in iterator(): for extension, _ in iterator():
if typ == "input" and hasattr(extension, "input_modifier"): if typ == "input" and hasattr(extension, "input_modifier"):
@@ -39,6 +44,7 @@ def apply_extensions(text, typ):
text = extension.bot_prefix_modifier(text) text = extension.bot_prefix_modifier(text)
return text return text
def create_extensions_block(): def create_extensions_block():
global setup_called global setup_called

View File

@@ -21,6 +21,9 @@ with open(Path(__file__).resolve().parent / '../css/html_4chan_style.css', 'r')
_4chan_css = css_f.read() _4chan_css = css_f.read()
with open(Path(__file__).resolve().parent / '../css/html_cai_style.css', 'r') as f: with open(Path(__file__).resolve().parent / '../css/html_cai_style.css', 'r') as f:
cai_css = f.read() cai_css = f.read()
with open(Path(__file__).resolve().parent / '../css/html_instruct_style.css', 'r') as f:
instruct_css = f.read()
def fix_newlines(string): def fix_newlines(string):
string = string.replace('\n', '\n\n') string = string.replace('\n', '\n\n')
@@ -29,6 +32,8 @@ def fix_newlines(string):
return string return string
# This could probably be generalized and improved # This could probably be generalized and improved
def convert_to_markdown(string): def convert_to_markdown(string):
string = string.replace('\\begin{code}', '```') string = string.replace('\\begin{code}', '```')
string = string.replace('\\end{code}', '```') string = string.replace('\\end{code}', '```')
@@ -38,11 +43,13 @@ def convert_to_markdown(string):
string = fix_newlines(string) string = fix_newlines(string)
return markdown.markdown(string, extensions=['fenced_code']) return markdown.markdown(string, extensions=['fenced_code'])
def generate_basic_html(string): def generate_basic_html(string):
string = convert_to_markdown(string) string = convert_to_markdown(string)
string = f'<style>{readable_css}</style><div class="container">{string}</div>' string = f'<style>{readable_css}</style><div class="container">{string}</div>'
return string return string
def process_post(post, c): def process_post(post, c):
t = post.split('\n') t = post.split('\n')
number = t[0].split(' ')[1] number = t[0].split(' ')[1]
@@ -57,6 +64,7 @@ def process_post(post, c):
src = f'<span class="name">Anonymous </span> <span class="number">No.{number}</span>\n{src}' src = f'<span class="name">Anonymous </span> <span class="number">No.{number}</span>\n{src}'
return src return src
def generate_4chan_html(f): def generate_4chan_html(f):
posts = [] posts = []
post = '' post = ''
@@ -96,6 +104,7 @@ def generate_4chan_html(f):
return output return output
def make_thumbnail(image): def make_thumbnail(image):
image = image.resize((350, round(image.size[1] / image.size[0] * 350)), Image.Resampling.LANCZOS) image = image.resize((350, round(image.size[1] / image.size[0] * 350)), Image.Resampling.LANCZOS)
if image.size[1] > 470: if image.size[1] > 470:
@@ -103,6 +112,7 @@ def make_thumbnail(image):
return image return image
def get_image_cache(path): def get_image_cache(path):
cache_folder = Path("cache") cache_folder = Path("cache")
if not cache_folder.exists(): if not cache_folder.exists():
@@ -117,11 +127,45 @@ def get_image_cache(path):
return image_cache[path][1] return image_cache[path][1]
def generate_chat_html(history, name1, name2, reset_cache=False):
def generate_instruct_html(history):
output = f'<style>{instruct_css}</style><div class="chat" id="chat">'
for i, _row in enumerate(history[::-1]):
row = [convert_to_markdown(entry) for entry in _row]
output += f"""
<div class="assistant-message">
<div class="text">
<div class="message-body">
{row[1]}
</div>
</div>
</div>
"""
if len(row[0]) == 0: # don't display empty user messages
continue
output += f"""
<div class="user-message">
<div class="text">
<div class="message-body">
{row[0]}
</div>
</div>
</div>
"""
output += "</div>"
return output
def generate_cai_chat_html(history, name1, name2, reset_cache=False):
output = f'<style>{cai_css}</style><div class="chat" id="chat">' output = f'<style>{cai_css}</style><div class="chat" id="chat">'
# The time.time() is to prevent the brower from caching the image # The time.time() is to prevent the brower from caching the image
suffix = f"?{time.time()}" if reset_cache else '' suffix = f"?{time.time()}" if reset_cache else f"?{name2}"
img_bot = f'<img src="file/cache/pfp_character.png{suffix}">' if Path("cache/pfp_character.png").exists() else '' img_bot = f'<img src="file/cache/pfp_character.png{suffix}">' if Path("cache/pfp_character.png").exists() else ''
img_me = f'<img src="file/cache/pfp_me.png{suffix}">' if Path("cache/pfp_me.png").exists() else '' img_me = f'<img src="file/cache/pfp_me.png{suffix}">' if Path("cache/pfp_me.png").exists() else ''
@@ -165,3 +209,18 @@ def generate_chat_html(history, name1, name2, reset_cache=False):
output += "</div>" output += "</div>"
return output return output
def generate_chat_html(history, name1, name2):
return generate_cai_chat_html(history, name1, name2)
def chat_html_wrapper(history, name1, name2, mode, reset_cache=False):
if mode == "cai-chat":
return generate_cai_chat_html(history, name1, name2, reset_cache)
elif mode == "chat":
return generate_chat_html(history, name1, name2)
elif mode == "instruct":
return generate_instruct_html(history)
else:
return ''

View File

@@ -0,0 +1,63 @@
'''
Based on
https://github.com/abetlen/llama-cpp-python
Documentation:
https://abetlen.github.io/llama-cpp-python/
'''
from llama_cpp import Llama
from modules import shared
from modules.callbacks import Iteratorize
class LlamaCppModel:
def __init__(self):
self.initialized = False
@classmethod
def from_pretrained(self, path):
result = self()
params = {
'model_path': str(path),
'n_ctx': 2048,
'seed': 0,
'n_threads': shared.args.threads or None
}
self.model = Llama(**params)
# This is ugly, but the model and the tokenizer are the same object in this library.
return result, result
def encode(self, string):
if type(string) is str:
string = string.encode()
return self.model.tokenize(string)
def generate(self, context="", token_count=20, temperature=1, top_p=1, top_k=50, repetition_penalty=1, callback=None):
if type(context) is str:
context = context.encode()
tokens = self.model.tokenize(context)
output = b""
count = 0
for token in self.model.generate(tokens, top_k=top_k, top_p=top_p, temp=temperature, repeat_penalty=repetition_penalty):
text = self.model.detokenize([token])
output += text
if callback:
callback(text.decode())
count += 1
if count >= token_count or (token == self.model.token_eos()):
break
return output.decode()
def generate_with_streaming(self, **kwargs):
with Iteratorize(self.generate, kwargs, callback=None) as generator:
reply = ''
for token in generator:
reply += token
yield reply

View File

@@ -10,7 +10,7 @@ import torch
import transformers import transformers
from accelerate import infer_auto_device_map, init_empty_weights from accelerate import infer_auto_device_map, init_empty_weights
from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer, from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer,
BitsAndBytesConfig) BitsAndBytesConfig, LlamaTokenizer)
import modules.shared as shared import modules.shared as shared
@@ -103,7 +103,7 @@ def load_model(model_name):
# llamacpp model # llamacpp model
elif shared.is_llamacpp: elif shared.is_llamacpp:
from modules.llamacpp_model import LlamaCppModel from modules.llamacpp_model_alternative import LlamaCppModel
model_file = list(Path(f'{shared.args.model_dir}/{model_name}').glob('ggml*.bin'))[0] model_file = list(Path(f'{shared.args.model_dir}/{model_name}').glob('ggml*.bin'))[0]
print(f"llama.cpp weights detected: {model_file}\n") print(f"llama.cpp weights detected: {model_file}\n")
@@ -172,6 +172,8 @@ def load_model(model_name):
# Loading the tokenizer # Loading the tokenizer
if any((k in shared.model_name.lower() for k in ['gpt4chan', 'gpt-4chan'])) and Path(f"{shared.args.model_dir}/gpt-j-6B/").exists(): if any((k in shared.model_name.lower() for k in ['gpt4chan', 'gpt-4chan'])) and Path(f"{shared.args.model_dir}/gpt-j-6B/").exists():
tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/gpt-j-6B/")) tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/gpt-j-6B/"))
elif type(model) is transformers.LlamaForCausalLM:
tokenizer = LlamaTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/{shared.model_name}/"), clean_up_tokenization_spaces=True)
else: else:
tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/{shared.model_name}/")) tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/{shared.model_name}/"))
tokenizer.truncation_side = 'left' tokenizer.truncation_side = 'left'
@@ -179,6 +181,7 @@ def load_model(model_name):
print(f"Loaded the model in {(time.time()-t0):.2f} seconds.") print(f"Loaded the model in {(time.time()-t0):.2f} seconds.")
return model, tokenizer return model, tokenizer
def load_soft_prompt(name): def load_soft_prompt(name):
if name == 'None': if name == 'None':
shared.soft_prompt = False shared.soft_prompt = False

View File

@@ -33,6 +33,7 @@ settings = {
'name2': 'Assistant', 'name2': 'Assistant',
'context': 'This is a conversation with your Assistant. The Assistant is very helpful and is eager to chat with you and answer your questions.', 'context': 'This is a conversation with your Assistant. The Assistant is very helpful and is eager to chat with you and answer your questions.',
'greeting': 'Hello there!', 'greeting': 'Hello there!',
'end_of_turn': '',
'stop_at_newline': False, 'stop_at_newline': False,
'chat_prompt_size': 2048, 'chat_prompt_size': 2048,
'chat_prompt_size_min': 0, 'chat_prompt_size_min': 0,
@@ -44,6 +45,7 @@ settings = {
'chat_default_extensions': ["gallery"], 'chat_default_extensions': ["gallery"],
'presets': { 'presets': {
'default': 'NovelAI-Sphinx Moth', 'default': 'NovelAI-Sphinx Moth',
'.*(alpaca|llama)': "LLaMA-Precise",
'.*pygmalion': 'NovelAI-Storywriter', '.*pygmalion': 'NovelAI-Storywriter',
'.*RWKV': 'Naive', '.*RWKV': 'Naive',
}, },
@@ -59,6 +61,7 @@ settings = {
} }
} }
def str2bool(v): def str2bool(v):
if isinstance(v, bool): if isinstance(v, bool):
return v return v
@@ -69,12 +72,13 @@ def str2bool(v):
else: else:
raise argparse.ArgumentTypeError('Boolean value expected.') raise argparse.ArgumentTypeError('Boolean value expected.')
parser = argparse.ArgumentParser(formatter_class=lambda prog: argparse.HelpFormatter(prog, max_help_position=54)) parser = argparse.ArgumentParser(formatter_class=lambda prog: argparse.HelpFormatter(prog, max_help_position=54))
# Basic settings # Basic settings
parser.add_argument('--notebook', action='store_true', help='Launch the web UI in notebook mode, where the output is written to the same text box as the input.') parser.add_argument('--notebook', action='store_true', help='Launch the web UI in notebook mode, where the output is written to the same text box as the input.')
parser.add_argument('--chat', action='store_true', help='Launch the web UI in chat mode.') parser.add_argument('--chat', action='store_true', help='Launch the web UI in chat mode with a style similar to the Character.AI website.')
parser.add_argument('--cai-chat', action='store_true', help='Launch the web UI in chat mode with a style similar to the Character.AI website.') parser.add_argument('--cai-chat', action='store_true', help='DEPRECATED: use --chat instead.')
parser.add_argument('--model', type=str, help='Name of the model to load by default.') parser.add_argument('--model', type=str, help='Name of the model to load by default.')
parser.add_argument('--lora', type=str, help='Name of the LoRA to apply to the model by default.') parser.add_argument('--lora', type=str, help='Name of the LoRA to apply to the model by default.')
parser.add_argument("--model-dir", type=str, default='models/', help="Path to directory with all the models") parser.add_argument("--model-dir", type=str, default='models/', help="Path to directory with all the models")
@@ -131,12 +135,18 @@ parser.add_argument("--gradio-auth-path", type=str, help='Set the gradio authent
args = parser.parse_args() args = parser.parse_args()
# Provisional, this will be deleted later # Deprecation warnings for parameters that have been renamed
deprecated_dict = {'gptq_bits': ['wbits', 0], 'gptq_model_type': ['model_type', None], 'gptq_pre_layer': ['prelayer', 0]} deprecated_dict = {'gptq_bits': ['wbits', 0], 'gptq_model_type': ['model_type', None], 'gptq_pre_layer': ['prelayer', 0]}
for k in deprecated_dict: for k in deprecated_dict:
if eval(f"args.{k}") != deprecated_dict[k][1]: if eval(f"args.{k}") != deprecated_dict[k][1]:
print(f"Warning: --{k} is deprecated and will be removed. Use --{deprecated_dict[k][0]} instead.") print(f"Warning: --{k} is deprecated and will be removed. Use --{deprecated_dict[k][0]} instead.")
exec(f"args.{deprecated_dict[k][0]} = args.{k}") exec(f"args.{deprecated_dict[k][0]} = args.{k}")
# Deprecation warnings for parameters that have been removed
if args.cai_chat:
print("Warning: --cai-chat is deprecated. Use --chat instead.")
args.chat = True
def is_chat(): def is_chat():
return any((args.chat, args.cai_chat)) return args.chat

View File

@@ -21,6 +21,7 @@ def get_max_prompt_length(tokens):
max_length -= shared.soft_prompt_tensor.shape[1] max_length -= shared.soft_prompt_tensor.shape[1]
return max_length return max_length
def encode(prompt, tokens_to_generate=0, add_special_tokens=True): def encode(prompt, tokens_to_generate=0, add_special_tokens=True):
if any((shared.is_RWKV, shared.is_llamacpp)): if any((shared.is_RWKV, shared.is_llamacpp)):
input_ids = shared.tokenizer.encode(str(prompt)) input_ids = shared.tokenizer.encode(str(prompt))
@@ -28,6 +29,10 @@ def encode(prompt, tokens_to_generate=0, add_special_tokens=True):
return input_ids return input_ids
else: else:
input_ids = shared.tokenizer.encode(str(prompt), return_tensors='pt', truncation=True, max_length=get_max_prompt_length(tokens_to_generate), add_special_tokens=add_special_tokens) input_ids = shared.tokenizer.encode(str(prompt), return_tensors='pt', truncation=True, max_length=get_max_prompt_length(tokens_to_generate), add_special_tokens=add_special_tokens)
if type(shared.tokenizer) is transformers.LlamaTokenizer and input_ids[0][0] == 29871:
input_ids = input_ids[:, 1:]
if shared.args.cpu: if shared.args.cpu:
return input_ids return input_ids
elif shared.args.flexgen: elif shared.args.flexgen:
@@ -40,6 +45,7 @@ def encode(prompt, tokens_to_generate=0, add_special_tokens=True):
else: else:
return input_ids.cuda() return input_ids.cuda()
def decode(output_ids): def decode(output_ids):
# Open Assistant relies on special tokens like <|endoftext|> # Open Assistant relies on special tokens like <|endoftext|>
if re.match('.*(oasst|galactica)-*', shared.model_name.lower()): if re.match('.*(oasst|galactica)-*', shared.model_name.lower()):
@@ -49,6 +55,7 @@ def decode(output_ids):
reply = reply.replace(r'<|endoftext|>', '') reply = reply.replace(r'<|endoftext|>', '')
return reply return reply
def generate_softprompt_input_tensors(input_ids): def generate_softprompt_input_tensors(input_ids):
inputs_embeds = shared.model.transformer.wte(input_ids) inputs_embeds = shared.model.transformer.wte(input_ids)
inputs_embeds = torch.cat((shared.soft_prompt_tensor, inputs_embeds), dim=1) inputs_embeds = torch.cat((shared.soft_prompt_tensor, inputs_embeds), dim=1)
@@ -57,6 +64,8 @@ def generate_softprompt_input_tensors(input_ids):
return inputs_embeds, filler_input_ids return inputs_embeds, filler_input_ids
# Removes empty replies from gpt4chan outputs # Removes empty replies from gpt4chan outputs
def fix_gpt4chan(s): def fix_gpt4chan(s):
for i in range(10): for i in range(10):
s = re.sub("--- [0-9]*\n>>[0-9]*\n---", "---", s) s = re.sub("--- [0-9]*\n>>[0-9]*\n---", "---", s)
@@ -65,6 +74,8 @@ def fix_gpt4chan(s):
return s return s
# Fix the LaTeX equations in galactica # Fix the LaTeX equations in galactica
def fix_galactica(s): def fix_galactica(s):
s = s.replace(r'\[', r'$') s = s.replace(r'\[', r'$')
s = s.replace(r'\]', r'$') s = s.replace(r'\]', r'$')
@@ -75,6 +86,7 @@ def fix_galactica(s):
s = re.sub(r"\n{3,}", "\n\n", s) s = re.sub(r"\n{3,}", "\n\n", s)
return s return s
def formatted_outputs(reply, model_name): def formatted_outputs(reply, model_name):
if not shared.is_chat(): if not shared.is_chat():
if 'galactica' in model_name.lower(): if 'galactica' in model_name.lower():
@@ -88,24 +100,29 @@ def formatted_outputs(reply, model_name):
else: else:
return reply return reply
def clear_torch_cache(): def clear_torch_cache():
gc.collect() gc.collect()
if not shared.args.cpu: if not shared.args.cpu:
torch.cuda.empty_cache() torch.cuda.empty_cache()
def set_manual_seed(seed): def set_manual_seed(seed):
if seed != -1: if seed != -1:
torch.manual_seed(seed) torch.manual_seed(seed)
if torch.cuda.is_available(): if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed) torch.cuda.manual_seed_all(seed)
def stop_everything_event(): def stop_everything_event():
shared.stop_everything = True shared.stop_everything = True
def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, eos_token=None, stopping_strings=[]):
def generate_reply(question, generate_state, eos_token=None, stopping_strings=[]):
clear_torch_cache() clear_torch_cache()
set_manual_seed(seed) set_manual_seed(generate_state['seed'])
shared.stop_everything = False shared.stop_everything = False
generate_params = {}
t0 = time.time() t0 = time.time()
original_question = question original_question = question
@@ -117,9 +134,12 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
# These models are not part of Hugging Face, so we handle them # These models are not part of Hugging Face, so we handle them
# separately and terminate the function call earlier # separately and terminate the function call earlier
if any((shared.is_RWKV, shared.is_llamacpp)): if any((shared.is_RWKV, shared.is_llamacpp)):
for k in ['temperature', 'top_p', 'top_k', 'repetition_penalty']:
generate_params[k] = generate_state[k]
generate_params["token_count"] = generate_state["max_new_tokens"]
try: try:
if shared.args.no_stream: if shared.args.no_stream:
reply = shared.model.generate(context=question, token_count=max_new_tokens, temperature=temperature, top_p=top_p, top_k=top_k, repetition_penalty=repetition_penalty) reply = shared.model.generate(context=question, **generate_params)
output = original_question + reply output = original_question + reply
if not shared.is_chat(): if not shared.is_chat():
reply = original_question + apply_extensions(reply, "output") reply = original_question + apply_extensions(reply, "output")
@@ -130,7 +150,7 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
# RWKV has proper streaming, which is very nice. # RWKV has proper streaming, which is very nice.
# No need to generate 8 tokens at a time. # No need to generate 8 tokens at a time.
for reply in shared.model.generate_with_streaming(context=question, token_count=max_new_tokens, temperature=temperature, top_p=top_p, top_k=top_k, repetition_penalty=repetition_penalty): for reply in shared.model.generate_with_streaming(context=question, **generate_params):
output = original_question + reply output = original_question + reply
if not shared.is_chat(): if not shared.is_chat():
reply = original_question + apply_extensions(reply, "output") reply = original_question + apply_extensions(reply, "output")
@@ -145,7 +165,7 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
print(f"Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens})") print(f"Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens})")
return return
input_ids = encode(question, max_new_tokens) input_ids = encode(question, generate_state['max_new_tokens'])
original_input_ids = input_ids original_input_ids = input_ids
output = input_ids[0] output = input_ids[0]
@@ -158,33 +178,21 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
t = [encode(string, 0, add_special_tokens=False) for string in stopping_strings] t = [encode(string, 0, add_special_tokens=False) for string in stopping_strings]
stopping_criteria_list.append(_SentinelTokenStoppingCriteria(sentinel_token_ids=t, starting_idx=len(input_ids[0]))) stopping_criteria_list.append(_SentinelTokenStoppingCriteria(sentinel_token_ids=t, starting_idx=len(input_ids[0])))
generate_params = {} generate_params["max_new_tokens"] = generate_state['max_new_tokens']
if not shared.args.flexgen: if not shared.args.flexgen:
generate_params.update({ for k in ["do_sample", "temperature", "top_p", "typical_p", "repetition_penalty", "encoder_repetition_penalty", "top_k", "min_length", "no_repeat_ngram_size", "num_beams", "penalty_alpha", "length_penalty", "early_stopping"]:
"max_new_tokens": max_new_tokens, generate_params[k] = generate_state[k]
"eos_token_id": eos_token_ids, generate_params["eos_token_id"] = eos_token_ids
"stopping_criteria": stopping_criteria_list, generate_params["stopping_criteria"] = stopping_criteria_list
"do_sample": do_sample, if shared.args.no_stream:
"temperature": temperature, generate_params["min_length"] = 0
"top_p": top_p,
"typical_p": typical_p,
"repetition_penalty": repetition_penalty,
"encoder_repetition_penalty": encoder_repetition_penalty,
"top_k": top_k,
"min_length": min_length if shared.args.no_stream else 0,
"no_repeat_ngram_size": no_repeat_ngram_size,
"num_beams": num_beams,
"penalty_alpha": penalty_alpha,
"length_penalty": length_penalty,
"early_stopping": early_stopping,
})
else: else:
generate_params.update({ for k in ["do_sample", "temperature"]:
"max_new_tokens": max_new_tokens if shared.args.no_stream else 8, generate_params[k] = generate_state[k]
"do_sample": do_sample, generate_params["stop"] = generate_state["eos_token_ids"][-1]
"temperature": temperature, if not shared.args.no_stream:
"stop": eos_token_ids[-1], generate_params["max_new_tokens"] = 8
})
if shared.args.no_cache: if shared.args.no_cache:
generate_params.update({"use_cache": False}) generate_params.update({"use_cache": False})
if shared.args.deepspeed: if shared.args.deepspeed:
@@ -244,7 +252,7 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
# Stream the output naively for FlexGen since it doesn't support 'stopping_criteria' # Stream the output naively for FlexGen since it doesn't support 'stopping_criteria'
else: else:
for i in range(max_new_tokens//8+1): for i in range(generate_state['max_new_tokens'] // 8 + 1):
clear_torch_cache() clear_torch_cache()
with torch.no_grad(): with torch.no_grad():
output = shared.model.generate(**generate_params)[0] output = shared.model.generate(**generate_params)[0]

View File

@@ -19,8 +19,10 @@ CURRENT_STEPS = 0
MAX_STEPS = 0 MAX_STEPS = 0
CURRENT_GRADIENT_ACCUM = 1 CURRENT_GRADIENT_ACCUM = 1
def get_dataset(path: str, ext: str): def get_dataset(path: str, ext: str):
return ['None'] + sorted(set((k.stem for k in Path(path).glob(f'*.{ext}'))), key=str.lower) return ['None'] + sorted(set([k.stem for k in Path(path).glob(f'*.{ext}') if k.stem != 'put-trainer-datasets-here']), key=str.lower)
def create_train_interface(): def create_train_interface():
with gr.Tab('Train LoRA', elem_id='lora-train-tab'): with gr.Tab('Train LoRA', elem_id='lora-train-tab'):
@@ -45,28 +47,34 @@ def create_train_interface():
with gr.Row(): with gr.Row():
dataset = gr.Dropdown(choices=get_dataset('training/datasets', 'json'), value='None', label='Dataset', info='The dataset file to use for training.') dataset = gr.Dropdown(choices=get_dataset('training/datasets', 'json'), value='None', label='Dataset', info='The dataset file to use for training.')
ui.create_refresh_button(dataset, lambda: None, lambda: {'choices': get_dataset('training/datasets', 'json')}, 'refresh-button') ui.create_refresh_button(dataset, lambda: None, lambda: {'choices': get_dataset('training/datasets', 'json')}, 'refresh-button')
eval_dataset = gr.Dropdown(choices=get_dataset('training/datasets', 'json'), value='None', label='Evaluation Dataset', info='The dataset file used to evaluate the model after training.') eval_dataset = gr.Dropdown(choices=get_dataset('training/datasets', 'json'), value='None', label='Evaluation Dataset', info='The (optional) dataset file used to evaluate the model after training.')
ui.create_refresh_button(eval_dataset, lambda: None, lambda: {'choices': get_dataset('training/datasets', 'json')}, 'refresh-button') ui.create_refresh_button(eval_dataset, lambda: None, lambda: {'choices': get_dataset('training/datasets', 'json')}, 'refresh-button')
format = gr.Dropdown(choices=get_dataset('training/formats', 'json'), value='None', label='Data Format', info='The format file used to decide how to format the dataset input.') format = gr.Dropdown(choices=get_dataset('training/formats', 'json'), value='None', label='Data Format', info='The format file used to decide how to format the dataset input.')
ui.create_refresh_button(format, lambda: None, lambda: {'choices': get_dataset('training/formats', 'json')}, 'refresh-button') ui.create_refresh_button(format, lambda: None, lambda: {'choices': get_dataset('training/formats', 'json')}, 'refresh-button')
with gr.Tab(label="Raw Text File"): with gr.Tab(label="Raw Text File"):
with gr.Row(): with gr.Row():
raw_text_file = gr.Dropdown(choices=get_dataset('training/datasets', 'txt'), value='None', label='Text File', info='The raw text file to use for training.') raw_text_file = gr.Dropdown(choices=get_dataset('training/datasets', 'txt'), value='None', label='Text File', info='The raw text file to use for training.')
ui.create_refresh_button(raw_text_file, lambda: None, lambda: {'choices': get_dataset('training/datasets', 'txt')}, 'refresh-button') ui.create_refresh_button(raw_text_file, lambda: None, lambda: {'choices': get_dataset('training/datasets', 'txt')}, 'refresh-button')
overlap_len = gr.Slider(label='Overlap Length', minimum=0,maximum=512, value=128, step=16, info='Overlap length - ie how many tokens from the prior chunk of text to include into the next chunk. (The chunks themselves will be of a size determined by Cutoff Length above). Setting overlap to exactly half the cutoff length may be ideal.') with gr.Row():
overlap_len = gr.Slider(label='Overlap Length', minimum=0, maximum=512, value=128, step=16, info='Overlap length - ie how many tokens from the prior chunk of text to include into the next chunk. (The chunks themselves will be of a size determined by Cutoff Length below). Setting overlap to exactly half the cutoff length may be ideal.')
newline_favor_len = gr.Slider(label='Prefer Newline Cut Length', minimum=0, maximum=512, value=128, step=16, info='Length (in characters, not tokens) of the maximum distance to shift an overlap cut by to ensure chunks cut at newlines. If too low, cuts may occur in the middle of lines.')
with gr.Row(): with gr.Row():
start_button = gr.Button("Start LoRA Training") start_button = gr.Button("Start LoRA Training")
stop_button = gr.Button("Interrupt") stop_button = gr.Button("Interrupt")
output = gr.Markdown(value="Ready") output = gr.Markdown(value="Ready")
start_button.click(do_train, [lora_name, micro_batch_size, batch_size, epochs, learning_rate, lora_rank, lora_alpha, lora_dropout, cutoff_len, dataset, eval_dataset, format, raw_text_file, overlap_len], [output]) start_button.click(do_train, [lora_name, micro_batch_size, batch_size, epochs, learning_rate, lora_rank, lora_alpha, lora_dropout,
cutoff_len, dataset, eval_dataset, format, raw_text_file, overlap_len, newline_favor_len], [output])
stop_button.click(do_interrupt, [], [], cancels=[], queue=False) stop_button.click(do_interrupt, [], [], cancels=[], queue=False)
def do_interrupt(): def do_interrupt():
global WANT_INTERRUPT global WANT_INTERRUPT
WANT_INTERRUPT = True WANT_INTERRUPT = True
class Callbacks(transformers.TrainerCallback): class Callbacks(transformers.TrainerCallback):
def on_step_begin(self, args: transformers.TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, **kwargs): def on_step_begin(self, args: transformers.TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, **kwargs):
global CURRENT_STEPS, MAX_STEPS global CURRENT_STEPS, MAX_STEPS
@@ -75,6 +83,7 @@ class Callbacks(transformers.TrainerCallback):
if WANT_INTERRUPT: if WANT_INTERRUPT:
control.should_epoch_stop = True control.should_epoch_stop = True
control.should_training_stop = True control.should_training_stop = True
def on_substep_end(self, args: transformers.TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, **kwargs): def on_substep_end(self, args: transformers.TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, **kwargs):
global CURRENT_STEPS global CURRENT_STEPS
CURRENT_STEPS += 1 CURRENT_STEPS += 1
@@ -82,6 +91,7 @@ class Callbacks(transformers.TrainerCallback):
control.should_epoch_stop = True control.should_epoch_stop = True
control.should_training_stop = True control.should_training_stop = True
def clean_path(base_path: str, path: str): def clean_path(base_path: str, path: str):
""""Strips unusual symbols and forcibly builds a path as relative to the intended directory.""" """"Strips unusual symbols and forcibly builds a path as relative to the intended directory."""
# TODO: Probably could do with a security audit to guarantee there's no ways this can be bypassed to target an unwanted path. # TODO: Probably could do with a security audit to guarantee there's no ways this can be bypassed to target an unwanted path.
@@ -91,8 +101,9 @@ def clean_path(base_path: str, path: str):
return path return path
return f'{Path(base_path).absolute()}/{path}' return f'{Path(base_path).absolute()}/{path}'
def do_train(lora_name: str, micro_batch_size: int, batch_size: int, epochs: int, learning_rate: str, lora_rank: int,
lora_alpha: int, lora_dropout: float, cutoff_len: int, dataset: str, eval_dataset: str, format: str, raw_text_file: str, overlap_len: int): def do_train(lora_name: str, micro_batch_size: int, batch_size: int, epochs: int, learning_rate: str, lora_rank: int, lora_alpha: int, lora_dropout: float,
cutoff_len: int, dataset: str, eval_dataset: str, format: str, raw_text_file: str, overlap_len: int, newline_favor_len: int):
global WANT_INTERRUPT, CURRENT_STEPS, MAX_STEPS, CURRENT_GRADIENT_ACCUM global WANT_INTERRUPT, CURRENT_STEPS, MAX_STEPS, CURRENT_GRADIENT_ACCUM
WANT_INTERRUPT = False WANT_INTERRUPT = False
CURRENT_STEPS = 0 CURRENT_STEPS = 0
@@ -103,6 +114,25 @@ def do_train(lora_name: str, micro_batch_size: int, batch_size: int, epochs: int
lora_name = f"{shared.args.lora_dir}/{clean_path(None, lora_name)}" lora_name = f"{shared.args.lora_dir}/{clean_path(None, lora_name)}"
actual_lr = float(learning_rate) actual_lr = float(learning_rate)
model_type = type(shared.model).__name__
if model_type != "LlamaForCausalLM":
if model_type == "PeftModelForCausalLM":
yield "You are trying to train a LoRA while you already have another LoRA loaded. This will work, but may have unexpected effects. *(Will continue anyway in 5 seconds, press `Interrupt` to stop.)*"
print("Warning: Training LoRA over top of another LoRA. May have unexpected effects.")
else:
yield "LoRA training has only currently been validated for LLaMA models. Unexpected errors may follow. *(Will continue anyway in 5 seconds, press `Interrupt` to stop.)*"
print(f"Warning: LoRA training has only currently been validated for LLaMA models. (Found model type: {model_type})")
time.sleep(5)
if shared.args.wbits > 0 or shared.args.gptq_bits > 0:
yield "LoRA training does not yet support 4bit. Please use `--load-in-8bit` for now."
return
elif not shared.args.load_in_8bit:
yield "It is highly recommended you use `--load-in-8bit` for LoRA training. *(Will continue anyway in 2 seconds, press `Interrupt` to stop.)*"
print("Warning: It is highly recommended you use `--load-in-8bit` for LoRA training.")
time.sleep(2) # Give it a moment for the message to show in UI before continuing
if cutoff_len <= 0 or micro_batch_size <= 0 or batch_size <= 0 or actual_lr <= 0 or lora_rank <= 0 or lora_alpha <= 0: if cutoff_len <= 0 or micro_batch_size <= 0 or batch_size <= 0 or actual_lr <= 0 or lora_rank <= 0 or lora_alpha <= 0:
yield "Cannot input zeroes." yield "Cannot input zeroes."
return return
@@ -126,15 +156,20 @@ def do_train(lora_name: str, micro_batch_size: int, batch_size: int, epochs: int
raw_text = file.read() raw_text = file.read()
tokens = shared.tokenizer.encode(raw_text) tokens = shared.tokenizer.encode(raw_text)
del raw_text # Note: could be a gig for a large dataset, so delete redundant data as we go to be safe on RAM del raw_text # Note: could be a gig for a large dataset, so delete redundant data as we go to be safe on RAM
tokens = list(split_chunks(tokens, cutoff_len - overlap_len)) tokens = list(split_chunks(tokens, cutoff_len - overlap_len))
for i in range(1, len(tokens)): for i in range(1, len(tokens)):
tokens[i] = tokens[i - 1][-overlap_len:] + tokens[i] tokens[i] = tokens[i - 1][-overlap_len:] + tokens[i]
text_chunks = [shared.tokenizer.decode(x) for x in tokens] text_chunks = [shared.tokenizer.decode(x) for x in tokens]
del tokens del tokens
data = Dataset.from_list([tokenize(x) for x in text_chunks])
train_data = data.shuffle() if newline_favor_len > 0:
eval_data = None text_chunks = [cut_chunk_for_newline(x, newline_favor_len) for x in text_chunks]
train_data = Dataset.from_list([tokenize(x) for x in text_chunks])
del text_chunks del text_chunks
train_data = train_data.shuffle()
eval_data = None
else: else:
if dataset in ['None', '']: if dataset in ['None', '']:
@@ -232,33 +267,37 @@ def do_train(lora_name: str, micro_batch_size: int, batch_size: int, epochs: int
# TODO: save/load checkpoints to resume from? # TODO: save/load checkpoints to resume from?
print("Starting training...") print("Starting training...")
yield "Starting..." yield "Starting..."
if WANT_INTERRUPT:
yield "Interrupted before start."
return
def threadedRun(): def threaded_run():
trainer.train() trainer.train()
thread = threading.Thread(target=threadedRun) thread = threading.Thread(target=threaded_run)
thread.start() thread.start()
lastStep = 0 last_step = 0
startTime = time.perf_counter() start_time = time.perf_counter()
while thread.is_alive(): while thread.is_alive():
time.sleep(0.5) time.sleep(0.5)
if WANT_INTERRUPT: if WANT_INTERRUPT:
yield "Interrupting, please wait... *(Run will stop after the current training step completes.)*" yield "Interrupting, please wait... *(Run will stop after the current training step completes.)*"
elif CURRENT_STEPS != lastStep:
lastStep = CURRENT_STEPS elif CURRENT_STEPS != last_step:
timeElapsed = time.perf_counter() - startTime last_step = CURRENT_STEPS
if timeElapsed <= 0: time_elapsed = time.perf_counter() - start_time
timerInfo = "" if time_elapsed <= 0:
totalTimeEstimate = 999 timer_info = ""
total_time_estimate = 999
else: else:
its = CURRENT_STEPS / timeElapsed its = CURRENT_STEPS / time_elapsed
if its > 1: if its > 1:
timerInfo = f"`{its:.2f}` it/s" timer_info = f"`{its:.2f}` it/s"
else: else:
timerInfo = f"`{1.0/its:.2f}` s/it" timer_info = f"`{1.0/its:.2f}` s/it"
totalTimeEstimate = (1.0/its) * (MAX_STEPS) total_time_estimate = (1.0 / its) * (MAX_STEPS)
yield f"Running... **{CURRENT_STEPS}** / **{MAX_STEPS}** ... {timerInfo}, `{timeElapsed:.0f}`/`{totalTimeEstimate:.0f}` seconds" yield f"Running... **{CURRENT_STEPS}** / **{MAX_STEPS}** ... {timer_info}, {format_time(time_elapsed)} / {format_time(total_time_estimate)} ... {format_time(total_time_estimate - time_elapsed)} remaining"
print("Training complete, saving...") print("Training complete, saving...")
lora_model.save_pretrained(lora_name) lora_model.save_pretrained(lora_name)
@@ -270,6 +309,31 @@ def do_train(lora_name: str, micro_batch_size: int, batch_size: int, epochs: int
print("Training complete!") print("Training complete!")
yield f"Done! LoRA saved to `{lora_name}`" yield f"Done! LoRA saved to `{lora_name}`"
def split_chunks(arr, step): def split_chunks(arr, step):
for i in range(0, len(arr), step): for i in range(0, len(arr), step):
yield arr[i:i + step] yield arr[i:i + step]
def cut_chunk_for_newline(chunk: str, max_length: int):
if '\n' not in chunk:
return chunk
first_newline = chunk.index('\n')
if first_newline < max_length:
chunk = chunk[first_newline + 1:]
if '\n' not in chunk:
return chunk
last_newline = chunk.rindex('\n')
if len(chunk) - last_newline < max_length:
chunk = chunk[:last_newline]
return chunk
def format_time(seconds: float):
if seconds < 120:
return f"`{seconds:.0f}` seconds"
minutes = seconds / 60
if minutes < 120:
return f"`{minutes:.0f}` minutes"
hours = minutes / 60
return f"`{hours:.0f}` hours"

View File

@@ -13,6 +13,7 @@ with open(Path(__file__).resolve().parent / '../css/main.js', 'r') as f:
with open(Path(__file__).resolve().parent / '../css/chat.js', 'r') as f: with open(Path(__file__).resolve().parent / '../css/chat.js', 'r') as f:
chat_js = f.read() chat_js = f.read()
class ToolButton(gr.Button, gr.components.FormComponent): class ToolButton(gr.Button, gr.components.FormComponent):
"""Small button with single emoji as text, fits inside gradio forms""" """Small button with single emoji as text, fits inside gradio forms"""
@@ -22,6 +23,7 @@ class ToolButton(gr.Button, gr.components.FormComponent):
def get_block_name(self): def get_block_name(self):
return "button" return "button"
def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id): def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id):
def refresh(): def refresh():
refresh_method() refresh_method()

View File

@@ -0,0 +1,6 @@
do_sample=True
top_p=0.1
top_k=40
temperature=0.7
repetition_penalty=1.18
typical_p=1.0

View File

@@ -3,14 +3,13 @@ bitsandbytes==0.37.2
datasets datasets
flexgen==0.1.7 flexgen==0.1.7
gradio==3.24.1 gradio==3.24.1
llamacpp==0.1.11
markdown markdown
numpy numpy
peft==0.2.0 peft==0.2.0
requests requests
rwkv==0.7.2 rwkv==0.7.3
safetensors==0.3.0 safetensors==0.3.0
sentencepiece sentencepiece
pyyaml pyyaml
tqdm tqdm
git+https://github.com/huggingface/transformers@9eae4aa57650c1dbe1becd4e0979f6ad1e572ac0 git+https://github.com/huggingface/transformers

166
server.py
View File

@@ -1,3 +1,7 @@
import os
os.environ['GRADIO_ANALYTICS_ENABLED'] = 'False'
import io import io
import json import json
import re import re
@@ -11,8 +15,8 @@ import gradio as gr
from PIL import Image from PIL import Image
import modules.extensions as extensions_module import modules.extensions as extensions_module
from modules import chat, shared, training, ui from modules import chat, shared, training, ui, api
from modules.html_generator import generate_chat_html from modules.html_generator import chat_html_wrapper
from modules.LoRA import add_lora_to_model from modules.LoRA import add_lora_to_model
from modules.models import load_model, load_soft_prompt from modules.models import load_model, load_soft_prompt
from modules.text_generation import (clear_torch_cache, generate_reply, from modules.text_generation import (clear_torch_cache, generate_reply,
@@ -30,15 +34,18 @@ if settings_file is not None:
for item in new_settings: for item in new_settings:
shared.settings[item] = new_settings[item] shared.settings[item] = new_settings[item]
def get_available_models(): def get_available_models():
if shared.args.flexgen: if shared.args.flexgen:
return sorted([re.sub('-np$', '', item.name) for item in list(Path(f'{shared.args.model_dir}/').glob('*')) if item.name.endswith('-np')], key=str.lower) return sorted([re.sub('-np$', '', item.name) for item in list(Path(f'{shared.args.model_dir}/').glob('*')) if item.name.endswith('-np')], key=str.lower)
else: else:
return sorted([re.sub('.pth$', '', item.name) for item in list(Path(f'{shared.args.model_dir}/').glob('*')) if not item.name.endswith(('.txt', '-np', '.pt', '.json'))], key=str.lower) return sorted([re.sub('.pth$', '', item.name) for item in list(Path(f'{shared.args.model_dir}/').glob('*')) if not item.name.endswith(('.txt', '-np', '.pt', '.json'))], key=str.lower)
def get_available_presets(): def get_available_presets():
return sorted(set((k.stem for k in Path('presets').glob('*.txt'))), key=str.lower) return sorted(set((k.stem for k in Path('presets').glob('*.txt'))), key=str.lower)
def get_available_prompts(): def get_available_prompts():
prompts = [] prompts = []
prompts += sorted(set((k.stem for k in Path('prompts').glob('[0-9]*.txt'))), key=str.lower, reverse=True) prompts += sorted(set((k.stem for k in Path('prompts').glob('[0-9]*.txt'))), key=str.lower, reverse=True)
@@ -46,23 +53,37 @@ def get_available_prompts():
prompts += ['None'] prompts += ['None']
return prompts return prompts
def get_available_characters(): def get_available_characters():
paths = (x for x in Path('characters').iterdir() if x.suffix in ('.json', '.yaml', '.yml')) paths = (x for x in Path('characters').iterdir() if x.suffix in ('.json', '.yaml', '.yml'))
return ['None'] + sorted(set((k.stem for k in paths if k.stem != "instruction-following")), key=str.lower)
def get_available_instruction_templates():
path = "characters/instruction-following"
paths = []
if os.path.exists(path):
paths = (x for x in Path(path).iterdir() if x.suffix in ('.json', '.yaml', '.yml'))
return ['None'] + sorted(set((k.stem for k in paths)), key=str.lower) return ['None'] + sorted(set((k.stem for k in paths)), key=str.lower)
def get_available_extensions(): def get_available_extensions():
return sorted(set(map(lambda x: x.parts[1], Path('extensions').glob('*/script.py'))), key=str.lower) return sorted(set(map(lambda x: x.parts[1], Path('extensions').glob('*/script.py'))), key=str.lower)
def get_available_softprompts(): def get_available_softprompts():
return ['None'] + sorted(set((k.stem for k in Path('softprompts').glob('*.zip'))), key=str.lower) return ['None'] + sorted(set((k.stem for k in Path('softprompts').glob('*.zip'))), key=str.lower)
def get_available_loras(): def get_available_loras():
return ['None'] + sorted([item.name for item in list(Path(shared.args.lora_dir).glob('*')) if not item.name.endswith(('.txt', '-np', '.pt', '.json'))], key=str.lower) return ['None'] + sorted([item.name for item in list(Path(shared.args.lora_dir).glob('*')) if not item.name.endswith(('.txt', '-np', '.pt', '.json'))], key=str.lower)
def unload_model(): def unload_model():
shared.model = shared.tokenizer = None shared.model = shared.tokenizer = None
clear_torch_cache() clear_torch_cache()
def load_model_wrapper(selected_model): def load_model_wrapper(selected_model):
if selected_model != shared.model_name: if selected_model != shared.model_name:
shared.model_name = selected_model shared.model_name = selected_model
@@ -73,11 +94,13 @@ def load_model_wrapper(selected_model):
return selected_model return selected_model
def load_lora_wrapper(selected_lora): def load_lora_wrapper(selected_lora):
add_lora_to_model(selected_lora) add_lora_to_model(selected_lora)
return selected_lora return selected_lora
def load_preset_values(preset_menu, return_dict=False):
def load_preset_values(preset_menu, state, return_dict=False):
generate_params = { generate_params = {
'do_sample': True, 'do_sample': True,
'temperature': 1, 'temperature': 1,
@@ -99,13 +122,14 @@ def load_preset_values(preset_menu, return_dict=False):
i = i.rstrip(',').strip().split('=') i = i.rstrip(',').strip().split('=')
if len(i) == 2 and i[0].strip() != 'tokens': if len(i) == 2 and i[0].strip() != 'tokens':
generate_params[i[0].strip()] = eval(i[1].strip()) generate_params[i[0].strip()] = eval(i[1].strip())
generate_params['temperature'] = min(1.99, generate_params['temperature']) generate_params['temperature'] = min(1.99, generate_params['temperature'])
if return_dict: if return_dict:
return generate_params return generate_params
else: else:
return generate_params['do_sample'], generate_params['temperature'], generate_params['top_p'], generate_params['typical_p'], generate_params['repetition_penalty'], generate_params['encoder_repetition_penalty'], generate_params['top_k'], generate_params['min_length'], generate_params['no_repeat_ngram_size'], generate_params['num_beams'], generate_params['penalty_alpha'], generate_params['length_penalty'], generate_params['early_stopping'] state.update(generate_params)
return state, *[generate_params[k] for k in ['do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']]
def upload_soft_prompt(file): def upload_soft_prompt(file):
with zipfile.ZipFile(io.BytesIO(file)) as zf: with zipfile.ZipFile(io.BytesIO(file)) as zf:
@@ -119,23 +143,14 @@ def upload_soft_prompt(file):
return name return name
def create_model_and_preset_menus():
with gr.Row():
with gr.Column():
with gr.Row():
shared.gradio['model_menu'] = gr.Dropdown(choices=available_models, value=shared.model_name, label='Model')
ui.create_refresh_button(shared.gradio['model_menu'], lambda : None, lambda : {'choices': get_available_models()}, 'refresh-button')
with gr.Column():
with gr.Row():
shared.gradio['preset_menu'] = gr.Dropdown(choices=available_presets, value=default_preset if not shared.args.flexgen else 'Naive', label='Generation parameters preset')
ui.create_refresh_button(shared.gradio['preset_menu'], lambda : None, lambda : {'choices': get_available_presets()}, 'refresh-button')
def save_prompt(text): def save_prompt(text):
fname = f"{datetime.now().strftime('%Y-%m-%d-%H:%M:%S')}.txt" fname = f"{datetime.now().strftime('%Y-%m-%d-%H%M%S')}.txt"
with open(Path(f'prompts/{fname}'), 'w', encoding='utf-8') as f: with open(Path(f'prompts/{fname}'), 'w', encoding='utf-8') as f:
f.write(text) f.write(text)
return f"Saved to prompts/{fname}" return f"Saved to prompts/{fname}"
def load_prompt(fname): def load_prompt(fname):
if fname in ['None', '']: if fname in ['None', '']:
return '' return ''
@@ -146,6 +161,7 @@ def load_prompt(fname):
text = text[:-1] text = text[:-1]
return text return text
def create_prompt_menus(): def create_prompt_menus():
with gr.Row(): with gr.Row():
with gr.Column(): with gr.Column():
@@ -161,12 +177,33 @@ def create_prompt_menus():
shared.gradio['prompt_menu'].change(load_prompt, [shared.gradio['prompt_menu']], [shared.gradio['textbox']], show_progress=False) shared.gradio['prompt_menu'].change(load_prompt, [shared.gradio['prompt_menu']], [shared.gradio['textbox']], show_progress=False)
shared.gradio['save_prompt'].click(save_prompt, [shared.gradio['textbox']], [shared.gradio['status']], show_progress=False) shared.gradio['save_prompt'].click(save_prompt, [shared.gradio['textbox']], [shared.gradio['status']], show_progress=False)
def create_model_menus():
with gr.Row():
with gr.Column():
with gr.Row():
shared.gradio['model_menu'] = gr.Dropdown(choices=available_models, value=shared.model_name, label='Model')
ui.create_refresh_button(shared.gradio['model_menu'], lambda: None, lambda: {'choices': get_available_models()}, 'refresh-button')
with gr.Column():
with gr.Row():
shared.gradio['lora_menu'] = gr.Dropdown(choices=available_loras, value=shared.lora_name, label='LoRA')
ui.create_refresh_button(shared.gradio['lora_menu'], lambda: None, lambda: {'choices': get_available_loras()}, 'refresh-button')
shared.gradio['model_menu'].change(load_model_wrapper, shared.gradio['model_menu'], shared.gradio['model_menu'], show_progress=True)
shared.gradio['lora_menu'].change(load_lora_wrapper, shared.gradio['lora_menu'], shared.gradio['lora_menu'], show_progress=True)
def create_settings_menus(default_preset): def create_settings_menus(default_preset):
generate_params = load_preset_values(default_preset if not shared.args.flexgen else 'Naive', return_dict=True) generate_params = load_preset_values(default_preset if not shared.args.flexgen else 'Naive', {}, return_dict=True)
for k in ['max_new_tokens', 'seed', 'stop_at_newline', 'chat_prompt_size', 'chat_generation_attempts']:
generate_params[k] = shared.settings[k]
shared.gradio['generate_state'] = gr.State(generate_params)
with gr.Row(): with gr.Row():
with gr.Column(): with gr.Column():
create_model_and_preset_menus() with gr.Row():
shared.gradio['preset_menu'] = gr.Dropdown(choices=available_presets, value=default_preset if not shared.args.flexgen else 'Naive', label='Generation parameters preset')
ui.create_refresh_button(shared.gradio['preset_menu'], lambda: None, lambda: {'choices': get_available_presets()}, 'refresh-button')
with gr.Column(): with gr.Column():
shared.gradio['seed'] = gr.Number(value=shared.settings['seed'], label='Seed (-1 for random)') shared.gradio['seed'] = gr.Number(value=shared.settings['seed'], label='Seed (-1 for random)')
@@ -190,7 +227,6 @@ def create_settings_menus(default_preset):
with gr.Box(): with gr.Box():
gr.Markdown('Contrastive search') gr.Markdown('Contrastive search')
shared.gradio['penalty_alpha'] = gr.Slider(0, 5, value=generate_params['penalty_alpha'], label='penalty_alpha') shared.gradio['penalty_alpha'] = gr.Slider(0, 5, value=generate_params['penalty_alpha'], label='penalty_alpha')
with gr.Box(): with gr.Box():
gr.Markdown('Beam search (uses a lot of VRAM)') gr.Markdown('Beam search (uses a lot of VRAM)')
with gr.Row(): with gr.Row():
@@ -200,10 +236,6 @@ def create_settings_menus(default_preset):
shared.gradio['length_penalty'] = gr.Slider(-5, 5, value=generate_params['length_penalty'], label='length_penalty') shared.gradio['length_penalty'] = gr.Slider(-5, 5, value=generate_params['length_penalty'], label='length_penalty')
shared.gradio['early_stopping'] = gr.Checkbox(value=generate_params['early_stopping'], label='early_stopping') shared.gradio['early_stopping'] = gr.Checkbox(value=generate_params['early_stopping'], label='early_stopping')
with gr.Row():
shared.gradio['lora_menu'] = gr.Dropdown(choices=available_loras, value=shared.lora_name, label='LoRA')
ui.create_refresh_button(shared.gradio['lora_menu'], lambda : None, lambda : {'choices': get_available_loras()}, 'refresh-button')
with gr.Accordion('Soft prompt', open=False): with gr.Accordion('Soft prompt', open=False):
with gr.Row(): with gr.Row():
shared.gradio['softprompts_menu'] = gr.Dropdown(choices=available_softprompts, value='None', label='Soft prompt') shared.gradio['softprompts_menu'] = gr.Dropdown(choices=available_softprompts, value='None', label='Soft prompt')
@@ -213,17 +245,15 @@ def create_settings_menus(default_preset):
with gr.Row(): with gr.Row():
shared.gradio['upload_softprompt'] = gr.File(type='binary', file_types=['.zip']) shared.gradio['upload_softprompt'] = gr.File(type='binary', file_types=['.zip'])
shared.gradio['model_menu'].change(load_model_wrapper, [shared.gradio['model_menu']], [shared.gradio['model_menu']], show_progress=True) shared.gradio['preset_menu'].change(load_preset_values, [shared.gradio[k] for k in ['preset_menu', 'generate_state']], [shared.gradio[k] for k in ['generate_state', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']])
shared.gradio['preset_menu'].change(load_preset_values, [shared.gradio['preset_menu']], [shared.gradio[k] for k in ['do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']]) shared.gradio['softprompts_menu'].change(load_soft_prompt, shared.gradio['softprompts_menu'], shared.gradio['softprompts_menu'], show_progress=True)
shared.gradio['lora_menu'].change(load_lora_wrapper, [shared.gradio['lora_menu']], [shared.gradio['lora_menu']], show_progress=True) shared.gradio['upload_softprompt'].upload(upload_soft_prompt, shared.gradio['upload_softprompt'], shared.gradio['softprompts_menu'])
shared.gradio['softprompts_menu'].change(load_soft_prompt, [shared.gradio['softprompts_menu']], [shared.gradio['softprompts_menu']], show_progress=True)
shared.gradio['upload_softprompt'].upload(upload_soft_prompt, [shared.gradio['upload_softprompt']], [shared.gradio['softprompts_menu']])
def set_interface_arguments(interface_mode, extensions, bool_active): def set_interface_arguments(interface_mode, extensions, bool_active):
modes = ["default", "notebook", "chat", "cai_chat"] modes = ["default", "notebook", "chat", "cai_chat"]
cmd_list = vars(shared.args) cmd_list = vars(shared.args)
bool_list = [k for k in cmd_list if type(cmd_list[k]) is bool and k not in modes] bool_list = [k for k in cmd_list if type(cmd_list[k]) is bool and k not in modes]
#int_list = [k for k in cmd_list if type(k) is int]
shared.args.extensions = extensions shared.args.extensions = extensions
for k in modes[1:]: for k in modes[1:]:
@@ -238,6 +268,7 @@ def set_interface_arguments(interface_mode, extensions, bool_active):
shared.need_restart = True shared.need_restart = True
available_models = get_available_models() available_models = get_available_models()
available_presets = get_available_presets() available_presets = get_available_presets()
available_characters = get_available_characters() available_characters = get_available_characters()
@@ -286,8 +317,8 @@ else:
default_text = load_prompt(shared.settings['prompts'][next((k for k in shared.settings['prompts'] if re.match(k.lower(), shared.model_name.lower())), 'default')]) default_text = load_prompt(shared.settings['prompts'][next((k for k in shared.settings['prompts'] if re.match(k.lower(), shared.model_name.lower())), 'default')])
title = 'Text generation web UI' title = 'Text generation web UI'
def create_interface():
def create_interface():
gen_events = [] gen_events = []
if shared.args.extensions is not None and len(shared.args.extensions) > 0: if shared.args.extensions is not None and len(shared.args.extensions) > 0:
extensions_module.load_extensions() extensions_module.load_extensions()
@@ -296,10 +327,7 @@ def create_interface():
if shared.is_chat(): if shared.is_chat():
shared.gradio['Chat input'] = gr.State() shared.gradio['Chat input'] = gr.State()
with gr.Tab("Text generation", elem_id="main"): with gr.Tab("Text generation", elem_id="main"):
if shared.args.cai_chat: shared.gradio['display'] = gr.HTML(value=chat_html_wrapper(shared.history['visible'], shared.settings['name1'], shared.settings['name2'], 'cai-chat'))
shared.gradio['display'] = gr.HTML(value=generate_chat_html(shared.history['visible'], shared.settings['name1'], shared.settings['name2']))
else:
shared.gradio['display'] = gr.Chatbot(value=shared.history['visible'], elem_id="gradio-chatbot")
shared.gradio['textbox'] = gr.Textbox(label='Input') shared.gradio['textbox'] = gr.Textbox(label='Input')
with gr.Row(): with gr.Row():
shared.gradio['Generate'] = gr.Button('Generate') shared.gradio['Generate'] = gr.Button('Generate')
@@ -316,13 +344,17 @@ def create_interface():
shared.gradio['Clear history-confirm'] = gr.Button('Confirm', variant="stop", visible=False) shared.gradio['Clear history-confirm'] = gr.Button('Confirm', variant="stop", visible=False)
shared.gradio['Clear history-cancel'] = gr.Button('Cancel', visible=False) shared.gradio['Clear history-cancel'] = gr.Button('Cancel', visible=False)
shared.gradio["Chat mode"] = gr.Radio(choices=["cai-chat", "chat", "instruct"], value="cai-chat", label="Mode")
shared.gradio["Instruction templates"] = gr.Dropdown(choices=get_available_instruction_templates(), label="Instruction template", value="None", visible=False)
with gr.Tab("Character", elem_id="chat-settings"): with gr.Tab("Character", elem_id="chat-settings"):
with gr.Row(): with gr.Row():
with gr.Column(scale=8): with gr.Column(scale=8):
shared.gradio['name1'] = gr.Textbox(value=shared.settings['name1'], lines=1, label='Your name') shared.gradio['name1'] = gr.Textbox(value=shared.settings['name1'], lines=1, label='Your name')
shared.gradio['name2'] = gr.Textbox(value=shared.settings['name2'], lines=1, label='Character\'s name') shared.gradio['name2'] = gr.Textbox(value=shared.settings['name2'], lines=1, label='Character\'s name')
shared.gradio['greeting'] = gr.Textbox(value=shared.settings['greeting'], lines=2, label='Greeting') shared.gradio['greeting'] = gr.Textbox(value=shared.settings['greeting'], lines=4, label='Greeting')
shared.gradio['context'] = gr.Textbox(value=shared.settings['context'], lines=8, label='Context') shared.gradio['context'] = gr.Textbox(value=shared.settings['context'], lines=4, label='Context')
shared.gradio['end_of_turn'] = gr.Textbox(value=shared.settings["end_of_turn"], lines=1, label='End of turn string')
with gr.Column(scale=1): with gr.Column(scale=1):
shared.gradio['character_picture'] = gr.Image(label='Character picture', type="pil") shared.gradio['character_picture'] = gr.Image(label='Character picture', type="pil")
shared.gradio['your_picture'] = gr.Image(label='Your picture', type="pil", value=Image.open(Path("cache/pfp_me.png")) if Path("cache/pfp_me.png").exists() else None) shared.gradio['your_picture'] = gr.Image(label='Your picture', type="pil", value=Image.open(Path("cache/pfp_me.png")) if Path("cache/pfp_me.png").exists() else None)
@@ -363,35 +395,35 @@ def create_interface():
shared.gradio['chat_prompt_size_slider'] = gr.Slider(minimum=shared.settings['chat_prompt_size_min'], maximum=shared.settings['chat_prompt_size_max'], step=1, label='Maximum prompt size in tokens', value=shared.settings['chat_prompt_size']) shared.gradio['chat_prompt_size_slider'] = gr.Slider(minimum=shared.settings['chat_prompt_size_min'], maximum=shared.settings['chat_prompt_size_max'], step=1, label='Maximum prompt size in tokens', value=shared.settings['chat_prompt_size'])
with gr.Column(): with gr.Column():
shared.gradio['chat_generation_attempts'] = gr.Slider(minimum=shared.settings['chat_generation_attempts_min'], maximum=shared.settings['chat_generation_attempts_max'], value=shared.settings['chat_generation_attempts'], step=1, label='Generation attempts (for longer replies)') shared.gradio['chat_generation_attempts'] = gr.Slider(minimum=shared.settings['chat_generation_attempts_min'], maximum=shared.settings['chat_generation_attempts_max'], value=shared.settings['chat_generation_attempts'], step=1, label='Generation attempts (for longer replies)')
shared.gradio['check'] = gr.Checkbox(value=shared.settings['stop_at_newline'], label='Stop generating at new line character?') shared.gradio['stop_at_newline'] = gr.Checkbox(value=shared.settings['stop_at_newline'], label='Stop generating at new line character?')
create_settings_menus(default_preset) create_settings_menus(default_preset)
function_call = 'chat.cai_chatbot_wrapper' if shared.args.cai_chat else 'chat.chatbot_wrapper' shared.input_params = [shared.gradio[k] for k in ['Chat input', 'generate_state', 'name1', 'name2', 'context', 'Chat mode', 'end_of_turn']]
shared.input_params = [shared.gradio[k] for k in ['Chat input', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'seed', 'name1', 'name2', 'context', 'check', 'chat_prompt_size_slider', 'chat_generation_attempts']]
def set_chat_input(textbox): def set_chat_input(textbox):
return textbox, "" return textbox, ""
gen_events.append(shared.gradio['Generate'].click(set_chat_input, shared.gradio['textbox'], [shared.gradio['Chat input'], shared.gradio['textbox']], show_progress=False)) gen_events.append(shared.gradio['Generate'].click(set_chat_input, shared.gradio['textbox'], [shared.gradio['Chat input'], shared.gradio['textbox']], show_progress=False))
gen_events.append(shared.gradio['Generate'].click(eval(function_call), shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream)) gen_events.append(shared.gradio['Generate'].click(chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream))
gen_events.append(shared.gradio['textbox'].submit(set_chat_input, shared.gradio['textbox'], [shared.gradio['Chat input'], shared.gradio['textbox']], show_progress=False)) gen_events.append(shared.gradio['textbox'].submit(set_chat_input, shared.gradio['textbox'], [shared.gradio['Chat input'], shared.gradio['textbox']], show_progress=False))
gen_events.append(shared.gradio['textbox'].submit(eval(function_call), shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream)) gen_events.append(shared.gradio['textbox'].submit(chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream))
gen_events.append(shared.gradio['Regenerate'].click(chat.regenerate_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream)) gen_events.append(shared.gradio['Regenerate'].click(chat.regenerate_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream))
gen_events.append(shared.gradio['Impersonate'].click(chat.impersonate_wrapper, shared.input_params, shared.gradio['textbox'], show_progress=shared.args.no_stream)) gen_events.append(shared.gradio['Impersonate'].click(chat.impersonate_wrapper, shared.input_params, shared.gradio['textbox'], show_progress=shared.args.no_stream))
shared.gradio['Stop'].click(stop_everything_event, [], [], queue=False, cancels=gen_events if shared.args.no_stream else None) shared.gradio['Stop'].click(stop_everything_event, [], [], queue=False, cancels=gen_events if shared.args.no_stream else None)
shared.gradio['Copy last reply'].click(chat.send_last_reply_to_input, [], shared.gradio['textbox'], show_progress=shared.args.no_stream) shared.gradio['Copy last reply'].click(chat.send_last_reply_to_input, [], shared.gradio['textbox'], show_progress=shared.args.no_stream)
shared.gradio['Replace last reply'].click(chat.replace_last_reply, [shared.gradio['textbox'], shared.gradio['name1'], shared.gradio['name2']], shared.gradio['display'], show_progress=shared.args.no_stream) shared.gradio['Replace last reply'].click(chat.replace_last_reply, [shared.gradio[k] for k in ['textbox', 'name1', 'name2', 'Chat mode']], shared.gradio['display'], show_progress=shared.args.no_stream)
# Clear history with confirmation # Clear history with confirmation
clear_arr = [shared.gradio[k] for k in ['Clear history-confirm', 'Clear history', 'Clear history-cancel']] clear_arr = [shared.gradio[k] for k in ['Clear history-confirm', 'Clear history', 'Clear history-cancel']]
shared.gradio['Clear history'].click(lambda: [gr.update(visible=True), gr.update(visible=False), gr.update(visible=True)], None, clear_arr) shared.gradio['Clear history'].click(lambda: [gr.update(visible=True), gr.update(visible=False), gr.update(visible=True)], None, clear_arr)
shared.gradio['Clear history-confirm'].click(lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, clear_arr) shared.gradio['Clear history-confirm'].click(lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, clear_arr)
shared.gradio['Clear history-confirm'].click(chat.clear_chat_log, [shared.gradio['name1'], shared.gradio['name2'], shared.gradio['greeting']], shared.gradio['display']) shared.gradio['Clear history-confirm'].click(chat.clear_chat_log, [shared.gradio[k] for k in ['name1', 'name2', 'greeting', 'Chat mode']], shared.gradio['display'])
shared.gradio['Clear history-cancel'].click(lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, clear_arr) shared.gradio['Clear history-cancel'].click(lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, clear_arr)
shared.gradio['Chat mode'].change(lambda x: gr.update(visible=x == 'instruct'), shared.gradio['Chat mode'], shared.gradio['Instruction templates'])
shared.gradio['Remove last'].click(chat.remove_last_message, [shared.gradio['name1'], shared.gradio['name2']], [shared.gradio['display'], shared.gradio['textbox']], show_progress=False) shared.gradio['Remove last'].click(chat.remove_last_message, [shared.gradio[k] for k in ['name1', 'name2', 'Chat mode']], [shared.gradio['display'], shared.gradio['textbox']], show_progress=False)
shared.gradio['download_button'].click(chat.save_history, inputs=[], outputs=[shared.gradio['download']]) shared.gradio['download_button'].click(chat.save_history, inputs=[], outputs=[shared.gradio['download']])
shared.gradio['Upload character'].click(chat.upload_character, [shared.gradio['upload_json'], shared.gradio['upload_img_bot']], [shared.gradio['character_menu']]) shared.gradio['Upload character'].click(chat.upload_character, [shared.gradio['upload_json'], shared.gradio['upload_img_bot']], [shared.gradio['character_menu']])
@@ -403,19 +435,21 @@ def create_interface():
shared.gradio['textbox'].submit(lambda x: '', shared.gradio['textbox'], shared.gradio['textbox'], show_progress=False) shared.gradio['textbox'].submit(lambda x: '', shared.gradio['textbox'], shared.gradio['textbox'], show_progress=False)
shared.gradio['textbox'].submit(lambda: chat.save_history(timestamp=False), [], [], show_progress=False) shared.gradio['textbox'].submit(lambda: chat.save_history(timestamp=False), [], [], show_progress=False)
shared.gradio['character_menu'].change(chat.load_character, [shared.gradio[k] for k in ['character_menu', 'name1', 'name2']], [shared.gradio[k] for k in ['name1', 'name2', 'character_picture', 'greeting', 'context', 'display']]) shared.gradio['character_menu'].change(chat.load_character, [shared.gradio[k] for k in ['character_menu', 'name1', 'name2', 'Chat mode']], [shared.gradio[k] for k in ['name1', 'name2', 'character_picture', 'greeting', 'context', 'end_of_turn', 'display']])
shared.gradio['upload_chat_history'].upload(chat.load_history, [shared.gradio['upload_chat_history'], shared.gradio['name1'], shared.gradio['name2']], []) shared.gradio['Instruction templates'].change(lambda character, name1, name2, mode: chat.load_character(character, name1, name2, mode), [shared.gradio[k] for k in ['Instruction templates', 'name1', 'name2', 'Chat mode']], [shared.gradio[k] for k in ['name1', 'name2', 'character_picture', 'greeting', 'context', 'end_of_turn', 'display']])
shared.gradio['upload_chat_history'].upload(chat.load_history, [shared.gradio[k] for k in ['upload_chat_history', 'name1', 'name2']], [])
shared.gradio['upload_img_tavern'].upload(chat.upload_tavern_character, [shared.gradio['upload_img_tavern'], shared.gradio['name1'], shared.gradio['name2']], [shared.gradio['character_menu']]) shared.gradio['upload_img_tavern'].upload(chat.upload_tavern_character, [shared.gradio['upload_img_tavern'], shared.gradio['name1'], shared.gradio['name2']], [shared.gradio['character_menu']])
shared.gradio['your_picture'].change(chat.upload_your_profile_picture, [shared.gradio[k] for k in ['your_picture', 'name1', 'name2']], shared.gradio['display']) shared.gradio['your_picture'].change(chat.upload_your_profile_picture, [shared.gradio[k] for k in ['your_picture', 'name1', 'name2', 'Chat mode']], shared.gradio['display'])
reload_func = chat.redraw_html if shared.args.cai_chat else lambda : shared.history['visible'] reload_inputs = [shared.gradio[k] for k in ['name1', 'name2', 'Chat mode']]
reload_inputs = [shared.gradio['name1'], shared.gradio['name2']] if shared.args.cai_chat else [] shared.gradio['upload_chat_history'].upload(chat.redraw_html, reload_inputs, [shared.gradio['display']])
shared.gradio['upload_chat_history'].upload(reload_func, reload_inputs, [shared.gradio['display']]) shared.gradio['Stop'].click(chat.redraw_html, reload_inputs, [shared.gradio['display']])
shared.gradio['Stop'].click(reload_func, reload_inputs, [shared.gradio['display']]) shared.gradio['Instruction templates'].change(chat.redraw_html, reload_inputs, [shared.gradio['display']])
shared.gradio['Chat mode'].change(chat.redraw_html, reload_inputs, [shared.gradio['display']])
shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js+ui.chat_js}}}") shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js+ui.chat_js}}}")
shared.gradio['interface'].load(lambda: chat.load_default_history(shared.settings['name1'], shared.settings['name2']), None, None) shared.gradio['interface'].load(lambda: chat.load_default_history(shared.settings['name1'], shared.settings['name2']), None, None)
shared.gradio['interface'].load(reload_func, reload_inputs, [shared.gradio['display']], show_progress=True) shared.gradio['interface'].load(chat.redraw_html, reload_inputs, [shared.gradio['display']], show_progress=True)
elif shared.args.notebook: elif shared.args.notebook:
with gr.Tab("Text generation", elem_id="main"): with gr.Tab("Text generation", elem_id="main"):
@@ -445,9 +479,9 @@ def create_interface():
with gr.Tab("Parameters", elem_id="parameters"): with gr.Tab("Parameters", elem_id="parameters"):
create_settings_menus(default_preset) create_settings_menus(default_preset)
shared.input_params = [shared.gradio[k] for k in ['textbox', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'seed']] shared.input_params = [shared.gradio[k] for k in ['textbox', 'generate_state']]
output_params = [shared.gradio[k] for k in ['textbox', 'markdown', 'html']] output_params = [shared.gradio[k] for k in ['textbox', 'markdown', 'html']]
gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream, api_name='textgen')) gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream))
gen_events.append(shared.gradio['textbox'].submit(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream)) gen_events.append(shared.gradio['textbox'].submit(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream))
shared.gradio['Stop'].click(stop_everything_event, [], [], queue=False, cancels=gen_events if shared.args.no_stream else None) shared.gradio['Stop'].click(stop_everything_event, [], [], queue=False, cancels=gen_events if shared.args.no_stream else None)
shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js}}}") shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js}}}")
@@ -478,14 +512,17 @@ def create_interface():
with gr.Tab("Parameters", elem_id="parameters"): with gr.Tab("Parameters", elem_id="parameters"):
create_settings_menus(default_preset) create_settings_menus(default_preset)
shared.input_params = [shared.gradio[k] for k in ['textbox', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'seed']] shared.input_params = [shared.gradio[k] for k in ['textbox', 'generate_state']]
output_params = [shared.gradio[k] for k in ['output_textbox', 'markdown', 'html']] output_params = [shared.gradio[k] for k in ['output_textbox', 'markdown', 'html']]
gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream, api_name='textgen')) gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream))
gen_events.append(shared.gradio['textbox'].submit(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream)) gen_events.append(shared.gradio['textbox'].submit(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream))
gen_events.append(shared.gradio['Continue'].click(generate_reply, [shared.gradio['output_textbox']] + shared.input_params[1:], output_params, show_progress=shared.args.no_stream)) gen_events.append(shared.gradio['Continue'].click(generate_reply, [shared.gradio['output_textbox']] + shared.input_params[1:], output_params, show_progress=shared.args.no_stream))
shared.gradio['Stop'].click(stop_everything_event, [], [], queue=False, cancels=gen_events if shared.args.no_stream else None) shared.gradio['Stop'].click(stop_everything_event, [], [], queue=False, cancels=gen_events if shared.args.no_stream else None)
shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js}}}") shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js}}}")
with gr.Tab("Model", elem_id="model-tab"):
create_model_menus()
with gr.Tab("Training", elem_id="training-tab"): with gr.Tab("Training", elem_id="training-tab"):
training.create_train_interface() training.create_train_interface()
@@ -499,7 +536,6 @@ def create_interface():
cmd_list = vars(shared.args) cmd_list = vars(shared.args)
bool_list = [k for k in cmd_list if type(cmd_list[k]) is bool and k not in modes] bool_list = [k for k in cmd_list if type(cmd_list[k]) is bool and k not in modes]
bool_active = [k for k in bool_list if vars(shared.args)[k]] bool_active = [k for k in bool_list if vars(shared.args)[k]]
#int_list = [k for k in cmd_list if type(k) is int]
gr.Markdown("*Experimental*") gr.Markdown("*Experimental*")
shared.gradio['interface_modes_menu'] = gr.Dropdown(choices=modes, value=current_mode, label="Mode") shared.gradio['interface_modes_menu'] = gr.Dropdown(choices=modes, value=current_mode, label="Mode")
@@ -513,6 +549,21 @@ def create_interface():
if shared.args.extensions is not None: if shared.args.extensions is not None:
extensions_module.create_extensions_block() extensions_module.create_extensions_block()
def change_dict_value(d, key, value):
d[key] = value
return d
for k in ['do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'max_new_tokens', 'seed', 'stop_at_newline', 'chat_prompt_size_slider', 'chat_generation_attempts']:
if k not in shared.gradio:
continue
if type(shared.gradio[k]) in [gr.Checkbox, gr.Number]:
shared.gradio[k].change(lambda state, value, copy=k: change_dict_value(state, copy, value), inputs=[shared.gradio['generate_state'], shared.gradio[k]], outputs=shared.gradio['generate_state'])
else:
shared.gradio[k].release(lambda state, value, copy=k: change_dict_value(state, copy, value), inputs=[shared.gradio['generate_state'], shared.gradio[k]], outputs=shared.gradio['generate_state'])
if not shared.is_chat():
api.create_apis()
# Authentication # Authentication
auth = None auth = None
if shared.args.gradio_auth_path is not None: if shared.args.gradio_auth_path is not None:
@@ -529,6 +580,7 @@ def create_interface():
else: else:
shared.gradio['interface'].launch(prevent_thread_lock=True, share=shared.args.share, server_port=shared.args.listen_port, inbrowser=shared.args.auto_launch, auth=auth) shared.gradio['interface'].launch(prevent_thread_lock=True, share=shared.args.share, server_port=shared.args.listen_port, inbrowser=shared.args.auto_launch, auth=auth)
create_interface() create_interface()
while True: while True: