56 Commits

Author SHA1 Message Date
oobabooga
98dcfb8e12 Minor fixes 2023-04-11 11:24:55 -03:00
oobabooga
d0a6b53b4b Clean up 2023-04-11 11:11:10 -03:00
oobabooga
0f6cbe6e95 Refactor the UI 2023-04-11 10:49:26 -03:00
oobabooga
64f5c90ee7 Fix the API extension 2023-04-10 20:14:38 -03:00
oobabooga
58b34c0841 Fix chat_prompt_size 2023-04-10 20:06:42 -03:00
oobabooga
5234071c04 Improve Instruct mode text readability 2023-04-10 17:41:07 -03:00
IggoOnCode
09d8119e3c Add CPU LoRA training (#938)
(It's very slow)
2023-04-10 17:29:00 -03:00
Alex "mcmonkey" Goodwin
0caf718a21 add on-page documentation to parameters (#1008) 2023-04-10 17:19:12 -03:00
oobabooga
85a7954823 Update settings-template.json 2023-04-10 16:53:07 -03:00
oobabooga
d37b4f76b1 Merge branch 'main' of github.com:oobabooga/text-generation-webui 2023-04-10 16:45:09 -03:00
oobabooga
bd04ff27ad Make the bos token optional 2023-04-10 16:44:22 -03:00
oobabooga
f035b01823 Update README.md 2023-04-10 16:20:23 -03:00
Jeff Lefebvre
b7ca89ba3f Mention that build-essential is required (#1013) 2023-04-10 16:19:10 -03:00
loeken
52339e9b20 add make/g++ to docker (#1015) 2023-04-10 16:18:07 -03:00
oobabooga
4961f43702 Improve header bar colors 2023-04-10 16:15:16 -03:00
oobabooga
617530296e Instruct mode color/style improvements 2023-04-10 16:04:21 -03:00
oobabooga
0f1627eff1 Don't treat Intruct mode histories as regular histories
* They must now be saved/loaded manually
* Also improved browser caching of pfps
* Also changed the global default preset
2023-04-10 15:48:07 -03:00
oobabooga
d679c4be13 Change a label 2023-04-10 11:44:37 -03:00
oobabooga
45244ed125 More descriptive download info 2023-04-10 11:42:12 -03:00
oobabooga
7e70741a4e Download models from Model tab (#954 from UsamaKenway/main) 2023-04-10 11:38:30 -03:00
oobabooga
11b23db8d4 Remove unused imports 2023-04-10 11:37:42 -03:00
oobabooga
2c14df81a8 Use download-model.py to download the model 2023-04-10 11:36:39 -03:00
oobabooga
c6e9ba20a4 Merge branch 'main' into UsamaKenway-main 2023-04-10 11:14:03 -03:00
oobabooga
843f672227 fix random seeds to actually randomize (#1004 from mcmonkey4eva/seed-fix) 2023-04-10 10:56:12 -03:00
oobabooga
769aa900ea Print the used seed 2023-04-10 10:53:31 -03:00
oobabooga
32d078487e Add llama-cpp-python to requirements.txt 2023-04-10 10:45:51 -03:00
Alex "mcmonkey" Goodwin
30befe492a fix random seeds to actually randomize
Without this fix, manual seeds get locked in.
2023-04-10 06:29:10 -07:00
oobabooga
1911504f82 Minor bug fix 2023-04-09 23:45:41 -03:00
BlueprintCoding
8178fde2cb Added dropdown to character bias. (#986) 2023-04-09 23:44:31 -03:00
oobabooga
dba2000d2b Do things that I am not proud of 2023-04-09 23:40:49 -03:00
oobabooga
65552d2157 Merge branch 'main' of github.com:oobabooga/text-generation-webui 2023-04-09 23:19:53 -03:00
oobabooga
8c6155251a More robust 4-bit model loading 2023-04-09 23:19:28 -03:00
MarkovInequality
992663fa20 Added xformers support to Llama (#950) 2023-04-09 23:08:40 -03:00
Brian O'Connor
625d81f495 Update character log logic (#977)
* When logs are cleared, save the cleared log over the old log files
* Generate a log file when a character is loaded the first time
2023-04-09 22:20:21 -03:00
oobabooga
57f768eaad Better preset in api-example.py 2023-04-09 22:18:40 -03:00
oobabooga
a3085dba07 Fix LlamaTokenizer eos_token (attempt) 2023-04-09 21:19:39 -03:00
oobabooga
120f5662cf Better handle spaces for Continue 2023-04-09 20:37:31 -03:00
oobabooga
b27d757fd1 Minor change 2023-04-09 20:06:20 -03:00
oobabooga
d29f4624e9 Add a Continue button to chat mode 2023-04-09 20:04:16 -03:00
oobabooga
170e0c05c4 Typo 2023-04-09 17:00:59 -03:00
oobabooga
34ec02d41d Make download-model.py importable 2023-04-09 16:59:59 -03:00
oobabooga
f91d3a3ff4 server.py readability 2023-04-09 14:46:32 -03:00
Usama Kenway
ebdf4c8c12 path fixed 2023-04-09 16:53:21 +05:00
Usama Kenway
7436dd5b4a download custom model menu (from hugging face) added in model tab 2023-04-09 16:11:43 +05:00
oobabooga
bce1b7fbb2 Update README.md 2023-04-09 02:19:40 -03:00
oobabooga
f7860ce192 Update README.md 2023-04-09 02:19:17 -03:00
oobabooga
ece8ed2c84 Update README.md 2023-04-09 02:18:42 -03:00
oobabooga
cc693a7546 Remove obsolete code 2023-04-09 00:51:07 -03:00
oobabooga
2fde50a800 Delete docker.md 2023-04-08 22:37:54 -03:00
loeken
acc235aced updated docs for docker, setup video added, removed left over GPTQ_VERSION from docker-compose (#940) 2023-04-08 22:35:15 -03:00
Blake Wyatt
df561fd896 Fix ggml downloading in download-model.py (#915) 2023-04-08 18:52:30 -03:00
oobabooga
d272ac46dd Add Pillow as a requirement 2023-04-08 18:48:46 -03:00
oobabooga
cb169d0834 Minor formatting changes 2023-04-08 17:34:07 -03:00
oobabooga
2f16d0afca Remove redundant events 2023-04-08 17:32:36 -03:00
oobabooga
a6a00cb82f Properly concatenate chat events 2023-04-08 17:25:21 -03:00
Φφ
c97c270040 Send_pictures small fix (#546) 2023-04-08 01:55:16 -03:00
22 changed files with 726 additions and 323 deletions

View File

@@ -26,7 +26,7 @@ LABEL maintainer="Your Name <your.email@example.com>"
LABEL description="Docker image for GPTQ-for-LLaMa and Text Generation WebUI" LABEL description="Docker image for GPTQ-for-LLaMa and Text Generation WebUI"
RUN apt-get update && \ RUN apt-get update && \
apt-get install --no-install-recommends -y git python3 python3-pip && \ apt-get install --no-install-recommends -y git python3 python3-pip make g++ && \
rm -rf /var/lib/apt/lists/* rm -rf /var/lib/apt/lists/*
RUN --mount=type=cache,target=/root/.cache/pip pip3 install virtualenv RUN --mount=type=cache,target=/root/.cache/pip pip3 install virtualenv

View File

@@ -1,11 +1,9 @@
# Text generation web UI # Text generation web UI
A gradio web UI for running Large Language Models like LLaMA, llama.cpp, GPT-J, OPT, and GALACTICA. A gradio web UI for running Large Language Models like LLaMA, llama.cpp, GPT-J, Pythia, OPT, and GALACTICA.
Its goal is to become the [AUTOMATIC1111/stable-diffusion-webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui) of text generation. Its goal is to become the [AUTOMATIC1111/stable-diffusion-webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui) of text generation.
[[Try it on Google Colab]](https://colab.research.google.com/github/oobabooga/AI-Notebooks/blob/main/Colab-TextGen-GPU.ipynb)
|![Image1](https://github.com/oobabooga/screenshots/raw/main/qa.png) | ![Image2](https://github.com/oobabooga/screenshots/raw/main/cai3.png) | |![Image1](https://github.com/oobabooga/screenshots/raw/main/qa.png) | ![Image2](https://github.com/oobabooga/screenshots/raw/main/cai3.png) |
|:---:|:---:| |:---:|:---:|
|![Image3](https://github.com/oobabooga/screenshots/raw/main/gpt4chan.png) | ![Image4](https://github.com/oobabooga/screenshots/raw/main/galactica.png) | |![Image3](https://github.com/oobabooga/screenshots/raw/main/gpt4chan.png) | ![Image4](https://github.com/oobabooga/screenshots/raw/main/galactica.png) |
@@ -34,7 +32,6 @@ Its goal is to become the [AUTOMATIC1111/stable-diffusion-webui](https://github.
* [LoRA (loading and training)](https://github.com/oobabooga/text-generation-webui/wiki/Using-LoRAs) * [LoRA (loading and training)](https://github.com/oobabooga/text-generation-webui/wiki/Using-LoRAs)
* Softprompts * Softprompts
* [Extensions](https://github.com/oobabooga/text-generation-webui/wiki/Extensions) * [Extensions](https://github.com/oobabooga/text-generation-webui/wiki/Extensions)
* [Google Colab](https://github.com/oobabooga/text-generation-webui/wiki/Running-on-Colab)
## Installation ## Installation
@@ -73,9 +70,15 @@ On Linux or WSL, it can be automatically installed with these two commands:
curl -sL "https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh" > "Miniconda3.sh" curl -sL "https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh" > "Miniconda3.sh"
bash Miniconda3.sh bash Miniconda3.sh
``` ```
Source: https://educe-ubc.github.io/conda.html Source: https://educe-ubc.github.io/conda.html
#### 0.1 (Ubuntu/WSL) Install build tools
```
sudo apt install build-essential
```
#### 1. Create a new conda environment #### 1. Create a new conda environment
``` ```
@@ -209,7 +212,7 @@ Optionally, you can use the following command-line flags:
| Flag | Description | | Flag | Description |
|---------------------------------------------|-------------| |---------------------------------------------|-------------|
| `--cpu` | Use the CPU to generate text. | | `--cpu` | Use the CPU to generate text. Warning: Training on CPU is extremely slow.|
| `--auto-devices` | Automatically split the model across the available GPU(s) and CPU. | | `--auto-devices` | Automatically split the model across the available GPU(s) and CPU. |
| `--gpu-memory GPU_MEMORY [GPU_MEMORY ...]` | Maxmimum GPU memory in GiB to be allocated per GPU. Example: `--gpu-memory 10` for a single GPU, `--gpu-memory 10 5` for two GPUs. You can also set values in MiB like `--gpu-memory 3500MiB`. | | `--gpu-memory GPU_MEMORY [GPU_MEMORY ...]` | Maxmimum GPU memory in GiB to be allocated per GPU. Example: `--gpu-memory 10` for a single GPU, `--gpu-memory 10 5` for two GPUs. You can also set values in MiB like `--gpu-memory 3500MiB`. |
| `--cpu-memory CPU_MEMORY` | Maximum CPU memory in GiB to allocate for offloaded weights. Same as above.| | `--cpu-memory CPU_MEMORY` | Maximum CPU memory in GiB to allocate for offloaded weights. Same as above.|
@@ -218,6 +221,8 @@ Optionally, you can use the following command-line flags:
| `--load-in-8bit` | Load the model with 8-bit precision.| | `--load-in-8bit` | Load the model with 8-bit precision.|
| `--bf16` | Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU. | | `--bf16` | Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU. |
| `--no-cache` | Set `use_cache` to False while generating text. This reduces the VRAM usage a bit with a performance cost. | | `--no-cache` | Set `use_cache` to False while generating text. This reduces the VRAM usage a bit with a performance cost. |
| `--xformers` | Use xformer's memory efficient attention. This should increase your tokens/s. |
| `--sdp-attention` | Use torch 2.0's sdp attention. |
#### llama.cpp #### llama.cpp

View File

@@ -22,10 +22,10 @@ server = "127.0.0.1"
params = { params = {
'max_new_tokens': 200, 'max_new_tokens': 200,
'do_sample': True, 'do_sample': True,
'temperature': 0.5, 'temperature': 0.72,
'top_p': 0.9, 'top_p': 0.73,
'typical_p': 1, 'typical_p': 1,
'repetition_penalty': 1.05, 'repetition_penalty': 1.1,
'encoder_repetition_penalty': 1.0, 'encoder_repetition_penalty': 1.0,
'top_k': 0, 'top_k': 0,
'min_length': 0, 'min_length': 0,

View File

@@ -25,9 +25,7 @@
.message-body {} .message-body {}
.message-body p { .message-body p {
margin-bottom: 0 !important;
font-size: 15px !important; font-size: 15px !important;
line-height: 1.428571429 !important;
} }
.message-body li { .message-body li {
@@ -51,15 +49,16 @@
padding: 15px; padding: 15px;
border-radius: 20px; border-radius: 20px;
background-color: #0000000f; background-color: #0000000f;
margin-bottom: 17.5px; margin-top: 9px !important;
margin-bottom: 18px !important;
} }
.gradio-container .chat .user-message { .gradio-container .chat .user-message {
padding: 15px; padding: 15px;
border-radius: 20px; border-radius: 20px;
margin-bottom: 17.5px !important; margin-bottom: 9px !important;
} }
.dark .chat .assistant-message { .dark .chat .assistant-message {
background-color: #ffffff21; background-color: #374151;
} }

View File

@@ -67,3 +67,13 @@ span.math.inline {
div.svelte-15lo0d8 > *, div.svelte-15lo0d8 > .form > * { div.svelte-15lo0d8 > *, div.svelte-15lo0d8 > .form > * {
flex-wrap: nowrap; flex-wrap: nowrap;
} }
.header_bar {
background-color: #f7f7f7;
margin-bottom: 40px;
}
.dark .header_bar {
border: none !important;
background-color: #8080802b;
}

View File

@@ -1,4 +1,4 @@
document.getElementById("main").parentNode.childNodes[0].style = "border: none; background-color: #8080802b; margin-bottom: 40px"; document.getElementById("main").parentNode.childNodes[0].classList.add("header_bar");
document.getElementById("main").parentNode.style = "padding: 0; margin: 0"; document.getElementById("main").parentNode.style = "padding: 0; margin: 0";
document.getElementById("main").parentNode.parentNode.parentNode.style = "padding: 0"; document.getElementById("main").parentNode.parentNode.parentNode.style = "padding: 0";

View File

@@ -6,7 +6,6 @@ services:
args: args:
# specify which cuda version your card supports: https://developer.nvidia.com/cuda-gpus # specify which cuda version your card supports: https://developer.nvidia.com/cuda-gpus
TORCH_CUDA_ARCH_LIST: ${TORCH_CUDA_ARCH_LIST} TORCH_CUDA_ARCH_LIST: ${TORCH_CUDA_ARCH_LIST}
GPTQ_VERSION: ${GPTQ_VERSION}
WEBUI_VERSION: ${WEBUI_VERSION} WEBUI_VERSION: ${WEBUI_VERSION}
env_file: .env env_file: .env
ports: ports:

View File

@@ -19,50 +19,6 @@ import requests
import tqdm import tqdm
from tqdm.contrib.concurrent import thread_map from tqdm.contrib.concurrent import thread_map
parser = argparse.ArgumentParser()
parser.add_argument('MODEL', type=str, default=None, nargs='?')
parser.add_argument('--branch', type=str, default='main', help='Name of the Git branch to download from.')
parser.add_argument('--threads', type=int, default=1, help='Number of files to download simultaneously.')
parser.add_argument('--text-only', action='store_true', help='Only download text files (txt/json).')
parser.add_argument('--output', type=str, default=None, help='The folder where the model should be saved.')
parser.add_argument('--clean', action='store_true', help='Does not resume the previous download.')
parser.add_argument('--check', action='store_true', help='Validates the checksums of model files.')
args = parser.parse_args()
def get_file(url, output_folder):
filename = Path(url.rsplit('/', 1)[1])
output_path = output_folder / filename
if output_path.exists() and not args.clean:
# Check if the file has already been downloaded completely
r = requests.get(url, stream=True)
total_size = int(r.headers.get('content-length', 0))
if output_path.stat().st_size >= total_size:
return
# Otherwise, resume the download from where it left off
headers = {'Range': f'bytes={output_path.stat().st_size}-'}
mode = 'ab'
else:
headers = {}
mode = 'wb'
r = requests.get(url, stream=True, headers=headers)
with open(output_path, mode) as f:
total_size = int(r.headers.get('content-length', 0))
block_size = 1024
with tqdm.tqdm(total=total_size, unit='iB', unit_scale=True, bar_format='{l_bar}{bar}| {n_fmt:6}/{total_fmt:6} {rate_fmt:6}') as t:
for data in r.iter_content(block_size):
t.update(len(data))
f.write(data)
def sanitize_branch_name(branch_name):
pattern = re.compile(r"^[a-zA-Z0-9._-]+$")
if pattern.match(branch_name):
return branch_name
else:
raise ValueError("Invalid branch name. Only alphanumeric characters, period, underscore and dash are allowed.")
def select_model_from_default_options(): def select_model_from_default_options():
models = { models = {
@@ -110,7 +66,20 @@ EleutherAI/pythia-1.4b-deduped
return model, branch return model, branch
def get_download_links_from_huggingface(model, branch): def sanitize_model_and_branch_names(model, branch):
if model[-1] == '/':
model = model[:-1]
if branch is None:
branch = "main"
else:
pattern = re.compile(r"^[a-zA-Z0-9._-]+$")
if not pattern.match(branch):
raise ValueError("Invalid branch name. Only alphanumeric characters, period, underscore and dash are allowed.")
return model, branch
def get_download_links_from_huggingface(model, branch, text_only=False):
base = "https://huggingface.co" base = "https://huggingface.co"
page = f"/api/models/{model}/tree/{branch}?cursor=" page = f"/api/models/{model}/tree/{branch}?cursor="
cursor = b"" cursor = b""
@@ -142,14 +111,14 @@ def get_download_links_from_huggingface(model, branch):
is_tokenizer = re.match("tokenizer.*\.model", fname) is_tokenizer = re.match("tokenizer.*\.model", fname)
is_text = re.match(".*\.(txt|json|py|md)", fname) or is_tokenizer is_text = re.match(".*\.(txt|json|py|md)", fname) or is_tokenizer
if any((is_pytorch, is_safetensors, is_pt, is_tokenizer, is_text)): if any((is_pytorch, is_safetensors, is_pt, is_ggml, is_tokenizer, is_text)):
if 'lfs' in dict[i]: if 'lfs' in dict[i]:
sha256.append([fname, dict[i]['lfs']['oid']]) sha256.append([fname, dict[i]['lfs']['oid']])
if is_text: if is_text:
links.append(f"https://huggingface.co/{model}/resolve/{branch}/{fname}") links.append(f"https://huggingface.co/{model}/resolve/{branch}/{fname}")
classifications.append('text') classifications.append('text')
continue continue
if not args.text_only: if not text_only:
links.append(f"https://huggingface.co/{model}/resolve/{branch}/{fname}") links.append(f"https://huggingface.co/{model}/resolve/{branch}/{fname}")
if is_safetensors: if is_safetensors:
has_safetensors = True has_safetensors = True
@@ -177,80 +146,125 @@ def get_download_links_from_huggingface(model, branch):
return links, sha256, is_lora return links, sha256, is_lora
def download_files(file_list, output_folder, num_threads=8): def get_output_folder(model, branch, is_lora, base_folder=None):
thread_map(lambda url: get_file(url, output_folder), file_list, max_workers=num_threads, disable=True) if base_folder is None:
if __name__ == '__main__':
model = args.MODEL
branch = args.branch
if model is None:
model, branch = select_model_from_default_options()
else:
if model[-1] == '/':
model = model[:-1]
branch = args.branch
if branch is None:
branch = "main"
else:
try:
branch = sanitize_branch_name(branch)
except ValueError as err_branch:
print(f"Error: {err_branch}")
sys.exit()
links, sha256, is_lora = get_download_links_from_huggingface(model, branch)
if args.output is not None:
base_folder = args.output
else:
base_folder = 'models' if not is_lora else 'loras' base_folder = 'models' if not is_lora else 'loras'
output_folder = f"{'_'.join(model.split('/')[-2:])}" output_folder = f"{'_'.join(model.split('/')[-2:])}"
if branch != 'main': if branch != 'main':
output_folder += f'_{branch}' output_folder += f'_{branch}'
output_folder = Path(base_folder) / output_folder output_folder = Path(base_folder) / output_folder
return output_folder
def get_single_file(url, output_folder, start_from_scratch=False):
filename = Path(url.rsplit('/', 1)[1])
output_path = output_folder / filename
if output_path.exists() and not start_from_scratch:
# Check if the file has already been downloaded completely
r = requests.get(url, stream=True)
total_size = int(r.headers.get('content-length', 0))
if output_path.stat().st_size >= total_size:
return
# Otherwise, resume the download from where it left off
headers = {'Range': f'bytes={output_path.stat().st_size}-'}
mode = 'ab'
else:
headers = {}
mode = 'wb'
r = requests.get(url, stream=True, headers=headers)
with open(output_path, mode) as f:
total_size = int(r.headers.get('content-length', 0))
block_size = 1024
with tqdm.tqdm(total=total_size, unit='iB', unit_scale=True, bar_format='{l_bar}{bar}| {n_fmt:6}/{total_fmt:6} {rate_fmt:6}') as t:
for data in r.iter_content(block_size):
t.update(len(data))
f.write(data)
def start_download_threads(file_list, output_folder, start_from_scratch=False, threads=1):
thread_map(lambda url: get_single_file(url, output_folder, start_from_scratch=start_from_scratch), file_list, max_workers=threads, disable=True)
def download_model_files(model, branch, links, sha256, output_folder, start_from_scratch=False, threads=1):
# Creating the folder and writing the metadata
if not output_folder.exists():
output_folder.mkdir()
with open(output_folder / 'huggingface-metadata.txt', 'w') as f:
f.write(f'url: https://huggingface.co/{model}\n')
f.write(f'branch: {branch}\n')
f.write(f'download date: {str(datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"))}\n')
sha256_str = ''
for i in range(len(sha256)):
sha256_str += f' {sha256[i][1]} {sha256[i][0]}\n'
if sha256_str != '':
f.write(f'sha256sum:\n{sha256_str}')
# Downloading the files
print(f"Downloading the model to {output_folder}")
start_download_threads(links, output_folder, start_from_scratch=start_from_scratch, threads=threads)
def check_model_files(model, branch, links, sha256, output_folder):
# Validate the checksums
validated = True
for i in range(len(sha256)):
fpath = (output_folder / sha256[i][0])
if not fpath.exists():
print(f"The following file is missing: {fpath}")
validated = False
continue
with open(output_folder / sha256[i][0], "rb") as f:
bytes = f.read()
file_hash = hashlib.sha256(bytes).hexdigest()
if file_hash != sha256[i][1]:
print(f'Checksum failed: {sha256[i][0]} {sha256[i][1]}')
validated = False
else:
print(f'Checksum validated: {sha256[i][0]} {sha256[i][1]}')
if validated:
print('[+] Validated checksums of all model files!')
else:
print('[-] Invalid checksums. Rerun download-model.py with the --clean flag.')
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('MODEL', type=str, default=None, nargs='?')
parser.add_argument('--branch', type=str, default='main', help='Name of the Git branch to download from.')
parser.add_argument('--threads', type=int, default=1, help='Number of files to download simultaneously.')
parser.add_argument('--text-only', action='store_true', help='Only download text files (txt/json).')
parser.add_argument('--output', type=str, default=None, help='The folder where the model should be saved.')
parser.add_argument('--clean', action='store_true', help='Does not resume the previous download.')
parser.add_argument('--check', action='store_true', help='Validates the checksums of model files.')
args = parser.parse_args()
branch = args.branch
model = args.MODEL
if model is None:
model, branch = select_model_from_default_options()
# Cleaning up the model/branch names
try:
model, branch = sanitize_model_and_branch_names(model, branch)
except ValueError as err_branch:
print(f"Error: {err_branch}")
sys.exit()
# Getting the download links from Hugging Face
links, sha256, is_lora = get_download_links_from_huggingface(model, branch, text_only=args.text_only)
# Getting the output folder
output_folder = get_output_folder(model, branch, is_lora, base_folder=args.output)
if args.check: if args.check:
# Validate the checksums # Check previously downloaded files
validated = True check_model_files(model, branch, links, sha256, output_folder)
for i in range(len(sha256)):
fpath = (output_folder / sha256[i][0])
if not fpath.exists():
print(f"The following file is missing: {fpath}")
validated = False
continue
with open(output_folder / sha256[i][0], "rb") as f:
bytes = f.read()
file_hash = hashlib.sha256(bytes).hexdigest()
if file_hash != sha256[i][1]:
print(f'Checksum failed: {sha256[i][0]} {sha256[i][1]}')
validated = False
else:
print(f'Checksum validated: {sha256[i][0]} {sha256[i][1]}')
if validated:
print('[+] Validated checksums of all model files!')
else:
print('[-] Invalid checksums. Rerun download-model.py with the --clean flag.')
else: else:
# Download files
# Creating the folder and writing the metadata download_model_files(model, branch, links, sha256, output_folder, threads=args.threads)
if not output_folder.exists():
output_folder.mkdir()
with open(output_folder / 'huggingface-metadata.txt', 'w') as f:
f.write(f'url: https://huggingface.co/{model}\n')
f.write(f'branch: {branch}\n')
f.write(f'download date: {str(datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"))}\n')
sha256_str = ''
for i in range(len(sha256)):
sha256_str += f' {sha256[i][1]} {sha256[i][0]}\n'
if sha256_str != '':
f.write(f'sha256sum:\n{sha256_str}')
# Downloading the files
print(f"Downloading the model to {output_folder}")
download_files(links, output_folder, args.threads)

View File

@@ -57,6 +57,7 @@ class Handler(BaseHTTPRequestHandler):
'length_penalty': float(body.get('length_penalty', 1)), 'length_penalty': float(body.get('length_penalty', 1)),
'early_stopping': bool(body.get('early_stopping', False)), 'early_stopping': bool(body.get('early_stopping', False)),
'seed': int(body.get('seed', -1)), 'seed': int(body.get('seed', -1)),
'add_bos_token': int(body.get('add_bos_token', True)),
} }
generator = generate_reply( generator = generate_reply(

View File

@@ -1,8 +1,23 @@
import gradio as gr import gradio as gr
import os
# get the current directory of the script
current_dir = os.path.dirname(os.path.abspath(__file__))
# check if the bias_options.txt file exists, if not, create it
bias_file = os.path.join(current_dir, "bias_options.txt")
if not os.path.isfile(bias_file):
with open(bias_file, "w") as f:
f.write("*I am so happy*\n*I am so sad*\n*I am so excited*\n*I am so bored*\n*I am so angry*")
# read bias options from the text file
with open(bias_file, "r") as f:
bias_options = [line.strip() for line in f.readlines()]
params = { params = {
"activate": True, "activate": True,
"bias string": " *I am so happy*", "bias string": " *I am so happy*",
"use custom string": False,
} }
@@ -11,7 +26,6 @@ def input_modifier(string):
This function is applied to your text inputs before This function is applied to your text inputs before
they are fed into the model. they are fed into the model.
""" """
return string return string
@@ -19,7 +33,6 @@ def output_modifier(string):
""" """
This function is applied to the model outputs. This function is applied to the model outputs.
""" """
return string return string
@@ -29,9 +42,11 @@ def bot_prefix_modifier(string):
the prefix text for the Bot and can be used to bias its the prefix text for the Bot and can be used to bias its
behavior. behavior.
""" """
if params['activate']: if params['activate']:
return f'{string} {params["bias string"].strip()} ' if params['use custom string']:
return f'{string} {params["custom string"].strip()} '
else:
return f'{string} {params["bias string"].strip()} '
else: else:
return string return string
@@ -39,8 +54,29 @@ def bot_prefix_modifier(string):
def ui(): def ui():
# Gradio elements # Gradio elements
activate = gr.Checkbox(value=params['activate'], label='Activate character bias') activate = gr.Checkbox(value=params['activate'], label='Activate character bias')
string = gr.Textbox(value=params["bias string"], label='Character bias') dropdown_string = gr.Dropdown(choices=bias_options, value=params["bias string"], label='Character bias', info='To edit the options in this dropdown edit the "bias_options.txt" file')
use_custom_string = gr.Checkbox(value=False, label='Use custom bias textbox instead of dropdown')
custom_string = gr.Textbox(value="", placeholder="Enter custom bias string", label="Custom Character Bias", info='To use this textbox activate the checkbox above')
# Event functions to update the parameters in the backend # Event functions to update the parameters in the backend
string.change(lambda x: params.update({"bias string": x}), string, None) def update_bias_string(x):
if x:
params.update({"bias string": x})
else:
params.update({"bias string": dropdown_string.get()})
return x
def update_custom_string(x):
params.update({"custom string": x})
dropdown_string.change(update_bias_string, dropdown_string, None)
custom_string.change(update_custom_string, custom_string, None)
activate.change(lambda x: params.update({"activate": x}), activate, None) activate.change(lambda x: params.update({"activate": x}), activate, None)
use_custom_string.change(lambda x: params.update({"use custom string": x}), use_custom_string, None)
# Group elements together depending on the selected option
def bias_string_group():
if use_custom_string.value:
return gr.Group([use_custom_string, custom_string])
else:
return dropdown_string

View File

@@ -25,7 +25,7 @@ def caption_image(raw_image):
def generate_chat_picture(picture, name1, name2): def generate_chat_picture(picture, name1, name2):
text = f'*{name1} sends {name2} a picture that contains the following: "{caption_image(picture)}"*' text = f'*{name1} sends {name2} a picture that contains the following: {caption_image(picture)}*'
# lower the resolution of sent images for the chat, otherwise the log size gets out of control quickly with all the base64 values in visible history # lower the resolution of sent images for the chat, otherwise the log size gets out of control quickly with all the base64 values in visible history
picture.thumbnail((300, 300)) picture.thumbnail((300, 300))
buffer = BytesIO() buffer = BytesIO()

View File

@@ -100,10 +100,10 @@ def load_quantized(model_name):
found_safetensors = list(path_to_model.glob("*.safetensors")) found_safetensors = list(path_to_model.glob("*.safetensors"))
pt_path = None pt_path = None
if len(found_pts) == 1: if len(found_pts) > 0:
pt_path = found_pts[0] pt_path = found_pts[-1]
elif len(found_safetensors) == 1: elif len(found_safetensors) > 0:
pt_path = found_safetensors[0] pt_path = found_safetensors[-1]
else: else:
if path_to_model.name.lower().startswith('llama-7b'): if path_to_model.name.lower().startswith('llama-7b'):
pt_model = f'llama-7b-{shared.args.wbits}bit' pt_model = f'llama-7b-{shared.args.wbits}bit'
@@ -119,13 +119,14 @@ def load_quantized(model_name):
# Try to find the .safetensors or .pt both in the model dir and in the subfolder # Try to find the .safetensors or .pt both in the model dir and in the subfolder
for path in [Path(p + ext) for ext in ['.safetensors', '.pt'] for p in [f"{shared.args.model_dir}/{pt_model}", f"{path_to_model}/{pt_model}"]]: for path in [Path(p + ext) for ext in ['.safetensors', '.pt'] for p in [f"{shared.args.model_dir}/{pt_model}", f"{path_to_model}/{pt_model}"]]:
if path.exists(): if path.exists():
print(f"Found {path}")
pt_path = path pt_path = path
break break
if not pt_path: if not pt_path:
print("Could not find the quantized model in .pt or .safetensors format, exiting...") print("Could not find the quantized model in .pt or .safetensors format, exiting...")
exit() exit()
else:
print(f"Found the following quantized model: {pt_path}")
# qwopqwop200's offload # qwopqwop200's offload
if model_type == 'llama' and shared.args.pre_layer: if model_type == 'llama' and shared.args.pre_layer:

View File

@@ -22,6 +22,7 @@ def generate_chat_prompt(user_input, max_new_tokens, name1, name2, context, chat
is_instruct = kwargs['is_instruct'] if 'is_instruct' in kwargs else False is_instruct = kwargs['is_instruct'] if 'is_instruct' in kwargs else False
end_of_turn = kwargs['end_of_turn'] if 'end_of_turn' in kwargs else '' end_of_turn = kwargs['end_of_turn'] if 'end_of_turn' in kwargs else ''
impersonate = kwargs['impersonate'] if 'impersonate' in kwargs else False impersonate = kwargs['impersonate'] if 'impersonate' in kwargs else False
_continue = kwargs['_continue'] if '_continue' in kwargs else False
also_return_rows = kwargs['also_return_rows'] if 'also_return_rows' in kwargs else False also_return_rows = kwargs['also_return_rows'] if 'also_return_rows' in kwargs else False
rows = [f"{context.strip()}\n"] rows = [f"{context.strip()}\n"]
@@ -39,7 +40,10 @@ def generate_chat_prompt(user_input, max_new_tokens, name1, name2, context, chat
i = len(shared.history['internal']) - 1 i = len(shared.history['internal']) - 1
while i >= 0 and len(encode(''.join(rows), max_new_tokens)[0]) < max_length: while i >= 0 and len(encode(''.join(rows), max_new_tokens)[0]) < max_length:
rows.insert(1, f"{prefix2}{shared.history['internal'][i][1].strip()}{end_of_turn}\n") if _continue and i == len(shared.history['internal']) - 1:
rows.insert(1, f"{prefix2}{shared.history['internal'][i][1]}")
else:
rows.insert(1, f"{prefix2}{shared.history['internal'][i][1].strip()}{end_of_turn}\n")
string = shared.history['internal'][i][0] string = shared.history['internal'][i][0]
if string not in ['', '<|BEGIN-VISIBLE-CHAT|>']: if string not in ['', '<|BEGIN-VISIBLE-CHAT|>']:
rows.insert(1, f"{prefix1}{string.strip()}{end_of_turn}\n") rows.insert(1, f"{prefix1}{string.strip()}{end_of_turn}\n")
@@ -48,6 +52,8 @@ def generate_chat_prompt(user_input, max_new_tokens, name1, name2, context, chat
if impersonate: if impersonate:
rows.append(f"{prefix1.strip() if not is_instruct else prefix1}") rows.append(f"{prefix1.strip() if not is_instruct else prefix1}")
limit = 2 limit = 2
elif _continue:
limit = 3
else: else:
# Adding the user message # Adding the user message
user_input = fix_newlines(user_input) user_input = fix_newlines(user_input)
@@ -68,16 +74,16 @@ def generate_chat_prompt(user_input, max_new_tokens, name1, name2, context, chat
return prompt return prompt
def extract_message_from_reply(reply, name1, name2, stop_at_newline): def extract_message_from_reply(reply, state):
next_character_found = False next_character_found = False
if stop_at_newline: if state['stop_at_newline']:
lines = reply.split('\n') lines = reply.split('\n')
reply = lines[0].strip() reply = lines[0].strip()
if len(lines) > 1: if len(lines) > 1:
next_character_found = True next_character_found = True
else: else:
for string in [f"\n{name1}:", f"\n{name2}:"]: for string in [f"\n{state['name1']}:", f"\n{state['name2']}:"]:
idx = reply.find(string) idx = reply.find(string)
if idx != -1: if idx != -1:
reply = reply[:idx] reply = reply[:idx]
@@ -86,7 +92,7 @@ def extract_message_from_reply(reply, name1, name2, stop_at_newline):
# If something like "\nYo" is generated just before "\nYou:" # If something like "\nYo" is generated just before "\nYou:"
# is completed, trim it # is completed, trim it
if not next_character_found: if not next_character_found:
for string in [f"\n{name1}:", f"\n{name2}:"]: for string in [f"\n{state['name1']}:", f"\n{state['name2']}:"]:
for j in range(len(string) - 1, 0, -1): for j in range(len(string) - 1, 0, -1):
if reply[-j:] == string[:j]: if reply[-j:] == string[:j]:
reply = reply[:-j] reply = reply[:-j]
@@ -99,20 +105,18 @@ def extract_message_from_reply(reply, name1, name2, stop_at_newline):
return reply, next_character_found return reply, next_character_found
def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn, regenerate=False): def chatbot_wrapper(text, state, regenerate=False, _continue=False):
if mode == 'instruct': if state['mode'] == 'instruct':
stopping_strings = [f"\n{name1}", f"\n{name2}"] stopping_strings = [f"\n{state['name1']}", f"\n{state['name2']}"]
else: else:
stopping_strings = [f"\n{name1}:", f"\n{name2}:"] stopping_strings = [f"\n{state['name1']}:", f"\n{state['name2']}:"]
# Defining some variables # Defining some variables
cumulative_reply = '' cumulative_reply = ''
last_reply = [shared.history['internal'][-1][1], shared.history['visible'][-1][1]] if _continue else None
just_started = True just_started = True
name1_original = name1
visible_text = custom_generate_chat_prompt = None visible_text = custom_generate_chat_prompt = None
eos_token = '\n' if generate_state['stop_at_newline'] else None eos_token = '\n' if state['stop_at_newline'] else None
if 'pygmalion' in shared.model_name.lower():
name1 = "You"
# Check if any extension wants to hijack this function call # Check if any extension wants to hijack this function call
for extension, _ in extensions_module.iterator(): for extension, _ in extensions_module.iterator():
@@ -124,28 +128,33 @@ def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_tu
if visible_text is None: if visible_text is None:
visible_text = text visible_text = text
text = apply_extensions(text, "input") if not _continue:
text = apply_extensions(text, "input")
# Generating the prompt # Generating the prompt
kwargs = {'end_of_turn': end_of_turn, 'is_instruct': mode == 'instruct'} kwargs = {
'end_of_turn': state['end_of_turn'],
'is_instruct': state['mode'] == 'instruct',
'_continue': _continue
}
if custom_generate_chat_prompt is None: if custom_generate_chat_prompt is None:
prompt = generate_chat_prompt(text, generate_state['max_new_tokens'], name1, name2, context, generate_state['chat_prompt_size'], **kwargs) prompt = generate_chat_prompt(text, state['max_new_tokens'], state['name1'], state['name2'], state['context'], state['chat_prompt_size'], **kwargs)
else: else:
prompt = custom_generate_chat_prompt(text, generate_state['max_new_tokens'], name1, name2, context, generate_state['chat_prompt_size'], **kwargs) prompt = custom_generate_chat_prompt(text, state['max_new_tokens'], state['name1'], state['name2'], state['context'], state['chat_prompt_size'], **kwargs)
# Yield *Is typing...* # Yield *Is typing...*
if not regenerate: if not any((regenerate, _continue)):
yield shared.history['visible'] + [[visible_text, shared.processing_message]] yield shared.history['visible'] + [[visible_text, shared.processing_message]]
# Generate # Generate
for i in range(generate_state['chat_generation_attempts']): for i in range(state['chat_generation_attempts']):
reply = None reply = None
for reply in generate_reply(f"{prompt}{' ' if len(cumulative_reply) > 0 else ''}{cumulative_reply}", generate_state, eos_token=eos_token, stopping_strings=stopping_strings): for reply in generate_reply(f"{prompt}{' ' if len(cumulative_reply) > 0 else ''}{cumulative_reply}", state, eos_token=eos_token, stopping_strings=stopping_strings):
reply = cumulative_reply + reply reply = cumulative_reply + reply
# Extracting the reply # Extracting the reply
reply, next_character_found = extract_message_from_reply(reply, name1, name2, generate_state['stop_at_newline']) reply, next_character_found = extract_message_from_reply(reply, state)
visible_reply = re.sub("(<USER>|<user>|{{user}})", name1_original, reply) visible_reply = re.sub("(<USER>|<user>|{{user}})", state['name1'], reply)
visible_reply = apply_extensions(visible_reply, "output") visible_reply = apply_extensions(visible_reply, "output")
# We need this global variable to handle the Stop event, # We need this global variable to handle the Stop event,
@@ -154,11 +163,17 @@ def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_tu
return shared.history['visible'] return shared.history['visible']
if just_started: if just_started:
just_started = False just_started = False
shared.history['internal'].append(['', '']) if not _continue:
shared.history['visible'].append(['', '']) shared.history['internal'].append(['', ''])
shared.history['visible'].append(['', ''])
shared.history['internal'][-1] = [text, reply] if _continue:
shared.history['visible'][-1] = [visible_text, visible_reply] sep = list(map(lambda x: ' ' if x[-1] != ' ' else '', last_reply))
shared.history['internal'][-1] = [text, f'{last_reply[0]}{sep[0]}{reply}']
shared.history['visible'][-1] = [visible_text, f'{last_reply[1]}{sep[1]}{visible_reply}']
else:
shared.history['internal'][-1] = [text, reply]
shared.history['visible'][-1] = [visible_text, visible_reply]
if not shared.args.no_stream: if not shared.args.no_stream:
yield shared.history['visible'] yield shared.history['visible']
if next_character_found: if next_character_found:
@@ -170,28 +185,25 @@ def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_tu
yield shared.history['visible'] yield shared.history['visible']
def impersonate_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn): def impersonate_wrapper(text, state):
if mode == 'instruct': if state['mode'] == 'instruct':
stopping_strings = [f"\n{name1}", f"\n{name2}"] stopping_strings = [f"\n{state['name1']}", f"\n{state['name2']}"]
else: else:
stopping_strings = [f"\n{name1}:", f"\n{name2}:"] stopping_strings = [f"\n{state['name1']}:", f"\n{state['name2']}:"]
# Defining some variables # Defining some variables
cumulative_reply = '' cumulative_reply = ''
eos_token = '\n' if generate_state['stop_at_newline'] else None eos_token = '\n' if state['stop_at_newline'] else None
if 'pygmalion' in shared.model_name.lower(): prompt = generate_chat_prompt(text, state['max_new_tokens'], state['name1'], state['name2'], state['context'], state['chat_prompt_size'], end_of_turn=state['end_of_turn'], impersonate=True)
name1 = "You"
prompt = generate_chat_prompt(text, generate_state['max_new_tokens'], name1, name2, context, generate_state['chat_prompt_size'], impersonate=True, end_of_turn=end_of_turn)
# Yield *Is typing...* # Yield *Is typing...*
yield shared.processing_message yield shared.processing_message
for i in range(generate_state['chat_generation_attempts']): for i in range(state['chat_generation_attempts']):
reply = None reply = None
for reply in generate_reply(f"{prompt}{' ' if len(cumulative_reply) > 0 else ''}{cumulative_reply}", generate_state, eos_token=eos_token, stopping_strings=stopping_strings): for reply in generate_reply(f"{prompt}{' ' if len(cumulative_reply) > 0 else ''}{cumulative_reply}", state, eos_token=eos_token, stopping_strings=stopping_strings):
reply = cumulative_reply + reply reply = cumulative_reply + reply
reply, next_character_found = extract_message_from_reply(reply, name1, name2, generate_state['stop_at_newline']) reply, next_character_found = extract_message_from_reply(reply, state)
yield reply yield reply
if next_character_found: if next_character_found:
break break
@@ -202,22 +214,32 @@ def impersonate_wrapper(text, generate_state, name1, name2, context, mode, end_o
yield reply yield reply
def cai_chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn): def cai_chatbot_wrapper(text, state):
for history in chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn): for history in chatbot_wrapper(text, state):
yield chat_html_wrapper(history, name1, name2, mode) yield chat_html_wrapper(history, state['name1'], state['name2'], state['mode'])
def regenerate_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn): def regenerate_wrapper(text, state):
if (len(shared.history['visible']) == 1 and not shared.history['visible'][0][0]) or len(shared.history['internal']) == 0: if (len(shared.history['visible']) == 1 and not shared.history['visible'][0][0]) or len(shared.history['internal']) == 0:
yield chat_html_wrapper(shared.history['visible'], name1, name2, mode) yield chat_html_wrapper(shared.history['visible'], state['name1'], state['name2'], state['mode'])
else: else:
last_visible = shared.history['visible'].pop() last_visible = shared.history['visible'].pop()
last_internal = shared.history['internal'].pop() last_internal = shared.history['internal'].pop()
# Yield '*Is typing...*' # Yield '*Is typing...*'
yield chat_html_wrapper(shared.history['visible'] + [[last_visible[0], shared.processing_message]], name1, name2, mode) yield chat_html_wrapper(shared.history['visible'] + [[last_visible[0], shared.processing_message]], state['name1'], state['name2'], state['mode'])
for history in chatbot_wrapper(last_internal[0], generate_state, name1, name2, context, mode, end_of_turn, regenerate=True): for history in chatbot_wrapper(last_internal[0], state, regenerate=True):
shared.history['visible'][-1] = [last_visible[0], history[-1][1]] shared.history['visible'][-1] = [last_visible[0], history[-1][1]]
yield chat_html_wrapper(shared.history['visible'], name1, name2, mode) yield chat_html_wrapper(shared.history['visible'], state['name1'], state['name2'], state['mode'])
def continue_wrapper(text, state):
if (len(shared.history['visible']) == 1 and not shared.history['visible'][0][0]) or len(shared.history['internal']) == 0:
yield chat_html_wrapper(shared.history['visible'], state['name1'], state['name2'], state['mode'])
else:
# Yield ' ...'
yield chat_html_wrapper(shared.history['visible'][:-1] + [[shared.history['visible'][-1][0], shared.history['visible'][-1][1] + ' ...']], state['name1'], state['name2'], state['mode'])
for history in chatbot_wrapper(shared.history['internal'][-1][0], state, _continue=True):
yield chat_html_wrapper(shared.history['visible'], state['name1'], state['name2'], state['mode'])
def remove_last_message(name1, name2, mode): def remove_last_message(name1, name2, mode):
@@ -257,6 +279,9 @@ def clear_chat_log(name1, name2, greeting, mode):
shared.history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', greeting]] shared.history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', greeting]]
shared.history['visible'] += [['', apply_extensions(greeting, "output")]] shared.history['visible'] += [['', apply_extensions(greeting, "output")]]
# Save cleared logs
save_history(mode)
return chat_html_wrapper(shared.history['visible'], name1, name2, mode) return chat_html_wrapper(shared.history['visible'], name1, name2, mode)
@@ -301,15 +326,23 @@ def tokenize_dialogue(dialogue, name1, name2, mode):
return history return history
def save_history(timestamp=True): def save_history(mode, timestamp=False):
if timestamp: # Instruct mode histories should not be saved as if
fname = f"{shared.character}_{datetime.now().strftime('%Y%m%d-%H%M%S')}.json" # Alpaca or Vicuna were characters
if mode == 'instruct':
if not timestamp:
return
fname = f"Instruct_{datetime.now().strftime('%Y%m%d-%H%M%S')}.json"
else: else:
fname = f"{shared.character}_persistent.json" if timestamp:
fname = f"{shared.character}_{datetime.now().strftime('%Y%m%d-%H%M%S')}.json"
else:
fname = f"{shared.character}_persistent.json"
if not Path('logs').exists(): if not Path('logs').exists():
Path('logs').mkdir() Path('logs').mkdir()
with open(Path(f'logs/{fname}'), 'w', encoding='utf-8') as f: with open(Path(f'logs/{fname}'), 'w', encoding='utf-8') as f:
f.write(json.dumps({'data': shared.history['internal'], 'data_visible': shared.history['visible']}, indent=2)) f.write(json.dumps({'data': shared.history['internal'], 'data_visible': shared.history['visible']}, indent=2))
return Path(f'logs/{fname}') return Path(f'logs/{fname}')
@@ -323,16 +356,6 @@ def load_history(file, name1, name2):
shared.history['visible'] = j['data_visible'] shared.history['visible'] = j['data_visible']
else: else:
shared.history['visible'] = copy.deepcopy(shared.history['internal']) shared.history['visible'] = copy.deepcopy(shared.history['internal'])
# Compatibility with Pygmalion AI's official web UI
elif 'chat' in j:
shared.history['internal'] = [':'.join(x.split(':')[1:]).strip() for x in j['chat']]
if len(j['chat']) > 0 and j['chat'][0].startswith(f'{name2}:'):
shared.history['internal'] = [['<|BEGIN-VISIBLE-CHAT|>', shared.history['internal'][0]]] + [[shared.history['internal'][i], shared.history['internal'][i + 1]] for i in range(1, len(shared.history['internal']) - 1, 2)]
shared.history['visible'] = copy.deepcopy(shared.history['internal'])
shared.history['visible'][0][0] = ''
else:
shared.history['internal'] = [[shared.history['internal'][i], shared.history['internal'][i + 1]] for i in range(0, len(shared.history['internal']) - 1, 2)]
shared.history['visible'] = copy.deepcopy(shared.history['internal'])
except: except:
shared.history['internal'] = tokenize_dialogue(file, name1, name2) shared.history['internal'] = tokenize_dialogue(file, name1, name2)
shared.history['visible'] = copy.deepcopy(shared.history['internal']) shared.history['visible'] = copy.deepcopy(shared.history['internal'])
@@ -368,8 +391,6 @@ def generate_pfp_cache(character):
def load_character(character, name1, name2, mode): def load_character(character, name1, name2, mode):
shared.character = character shared.character = character
shared.history['internal'] = []
shared.history['visible'] = []
context = greeting = end_of_turn = "" context = greeting = end_of_turn = ""
greeting_field = 'greeting' greeting_field = 'greeting'
picture = None picture = None
@@ -414,13 +435,22 @@ def load_character(character, name1, name2, mode):
greeting = shared.settings['greeting'] greeting = shared.settings['greeting']
end_of_turn = shared.settings['end_of_turn'] end_of_turn = shared.settings['end_of_turn']
if Path(f'logs/{shared.character}_persistent.json').exists(): if mode != 'instruct':
load_history(open(Path(f'logs/{shared.character}_persistent.json'), 'rb').read(), name1, name2) shared.history['internal'] = []
elif greeting != "": shared.history['visible'] = []
shared.history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', greeting]]
shared.history['visible'] += [['', apply_extensions(greeting, "output")]]
return name1, name2, picture, greeting, context, end_of_turn, chat_html_wrapper(shared.history['visible'], name1, name2, mode, reset_cache=True) if Path(f'logs/{shared.character}_persistent.json').exists():
load_history(open(Path(f'logs/{shared.character}_persistent.json'), 'rb').read(), name1, name2)
else:
# Insert greeting if it exists
if greeting != "":
shared.history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', greeting]]
shared.history['visible'] += [['', apply_extensions(greeting, "output")]]
# Create .json log files since they don't already exist
save_history(mode)
return name1, name2, picture, greeting, context, end_of_turn, chat_html_wrapper(shared.history['visible'], name1, name2, mode)
def load_default_history(name1, name2): def load_default_history(name1, name2):

View File

@@ -164,10 +164,9 @@ def generate_instruct_html(history):
def generate_cai_chat_html(history, name1, name2, reset_cache=False): def generate_cai_chat_html(history, name1, name2, reset_cache=False):
output = f'<style>{cai_css}</style><div class="chat" id="chat">' output = f'<style>{cai_css}</style><div class="chat" id="chat">'
# The time.time() is to prevent the brower from caching the image # We use ?name2 and ?time.time() to force the browser to reset caches
suffix = f"?{time.time()}" if reset_cache else f"?{name2}" img_bot = f'<img src="file/cache/pfp_character.png?{name2}">' if Path("cache/pfp_character.png").exists() else ''
img_bot = f'<img src="file/cache/pfp_character.png{suffix}">' if Path("cache/pfp_character.png").exists() else '' img_me = f'<img src="file/cache/pfp_me.png?{time.time() if reset_cache else ""}">' if Path("cache/pfp_me.png").exists() else ''
img_me = f'<img src="file/cache/pfp_me.png{suffix}">' if Path("cache/pfp_me.png").exists() else ''
for i, _row in enumerate(history[::-1]): for i, _row in enumerate(history[::-1]):
row = [convert_to_markdown(entry) for entry in _row] row = [convert_to_markdown(entry) for entry in _row]

View File

@@ -0,0 +1,176 @@
import math
import sys
import torch
import torch.nn as nn
import transformers.models.llama.modeling_llama
from typing import Optional
from typing import Tuple
import modules.shared as shared
if shared.args.xformers:
try:
import xformers.ops
except Exception:
print("🔴 xformers not found! Please install it before trying to use it.", file=sys.stderr)
def hijack_llama_attention():
if shared.args.xformers:
transformers.models.llama.modeling_llama.LlamaAttention.forward = xformers_forward
print("Replaced attention with xformers_attention")
elif shared.args.sdp_attention:
transformers.models.llama.modeling_llama.LlamaAttention.forward = sdp_attention_forward
print("Replaced attention with sdp_attention")
def xformers_forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = transformers.models.llama.modeling_llama.apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
# [bsz, nh, t, hd]
if past_key_value is not None:
# reuse k, v, self_attention
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
past_key_value = (key_states, value_states) if use_cache else None
#We only apply xformers optimizations if we don't need to output the whole attention matrix
if not output_attentions:
dtype = query_states.dtype
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
#This is a nasty hack. We know attention_mask in transformers is either LowerTriangular or all Zeros.
#We therefore check if one element in the upper triangular portion is zero. If it is, then the mask is all zeros.
if attention_mask is None or attention_mask[0, 0, 0, 1] == 0:
# input and output should be of form (bsz, q_len, num_heads, head_dim)
attn_output = xformers.ops.memory_efficient_attention(query_states, key_states, value_states, attn_bias=None)
else:
# input and output should be of form (bsz, q_len, num_heads, head_dim)
attn_output = xformers.ops.memory_efficient_attention(query_states, key_states, value_states, attn_bias=xformers.ops.LowerTriangularMask())
attn_weights = None
else:
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
raise ValueError(
f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is"
f" {attn_weights.size()}"
)
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
)
attn_weights = attn_weights + attention_mask
attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_output = torch.matmul(attn_weights, value_states)
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
attn_output = self.o_proj(attn_output)
return attn_output, attn_weights, past_key_value
def sdp_attention_forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = transformers.models.llama.modeling_llama.apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
# [bsz, nh, t, hd]
if past_key_value is not None:
# reuse k, v, self_attention
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
past_key_value = (key_states, value_states) if use_cache else None
#We only apply sdp attention if we don't need to output the whole attention matrix
if not output_attentions:
attn_output = torch.nn.functional.scaled_dot_product_attention(query_states, key_states, value_states, attn_mask=attention_mask, is_causal=False)
attn_weights = None
else:
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
raise ValueError(
f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is"
f" {attn_weights.size()}"
)
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
)
attn_weights = attn_weights + attention_mask
attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_output = torch.matmul(attn_weights, value_states)
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
attn_output = self.o_proj(attn_output)
return attn_output, attn_weights, past_key_value

View File

@@ -14,6 +14,7 @@ from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer,
BitsAndBytesConfig, LlamaTokenizer) BitsAndBytesConfig, LlamaTokenizer)
import modules.shared as shared import modules.shared as shared
from modules import llama_attn_hijack
transformers.logging.set_verbosity_error() transformers.logging.set_verbosity_error()
@@ -169,11 +170,23 @@ def load_model(model_name):
model = AutoModelForCausalLM.from_pretrained(checkpoint, **params) model = AutoModelForCausalLM.from_pretrained(checkpoint, **params)
# Hijack attention with xformers
if any((shared.args.xformers, shared.args.sdp_attention)):
llama_attn_hijack.hijack_llama_attention()
# Loading the tokenizer # Loading the tokenizer
if any((k in shared.model_name.lower() for k in ['gpt4chan', 'gpt-4chan'])) and Path(f"{shared.args.model_dir}/gpt-j-6B/").exists(): if any((k in shared.model_name.lower() for k in ['gpt4chan', 'gpt-4chan'])) and Path(f"{shared.args.model_dir}/gpt-j-6B/").exists():
tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/gpt-j-6B/")) tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/gpt-j-6B/"))
elif type(model) is transformers.LlamaForCausalLM: elif type(model) is transformers.LlamaForCausalLM:
tokenizer = LlamaTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/{shared.model_name}/"), clean_up_tokenization_spaces=True) tokenizer = LlamaTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/{shared.model_name}/"), clean_up_tokenization_spaces=True)
# Leaving this here until the LLaMA tokenizer gets figured out.
# For some people this fixes things, for others it causes an error.
try:
tokenizer.eos_token_id = 2
tokenizer.bos_token_id = 1
tokenizer.pad_token_id = 0
except:
pass
else: else:
tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/{shared.model_name}/")) tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/{shared.model_name}/"))
tokenizer.truncation_side = 'left' tokenizer.truncation_side = 'left'

View File

@@ -35,6 +35,7 @@ settings = {
'greeting': 'Hello there!', 'greeting': 'Hello there!',
'end_of_turn': '', 'end_of_turn': '',
'stop_at_newline': False, 'stop_at_newline': False,
'add_bos_token': True,
'chat_prompt_size': 2048, 'chat_prompt_size': 2048,
'chat_prompt_size_min': 0, 'chat_prompt_size_min': 0,
'chat_prompt_size_max': 2048, 'chat_prompt_size_max': 2048,
@@ -44,7 +45,7 @@ settings = {
'default_extensions': [], 'default_extensions': [],
'chat_default_extensions': ["gallery"], 'chat_default_extensions': ["gallery"],
'presets': { 'presets': {
'default': 'NovelAI-Sphinx Moth', 'default': 'Default',
'.*(alpaca|llama)': "LLaMA-Precise", '.*(alpaca|llama)': "LLaMA-Precise",
'.*pygmalion': 'NovelAI-Storywriter', '.*pygmalion': 'NovelAI-Storywriter',
'.*RWKV': 'Naive', '.*RWKV': 'Naive',
@@ -89,7 +90,7 @@ parser.add_argument('--extensions', type=str, nargs="+", help='The list of exten
parser.add_argument('--verbose', action='store_true', help='Print the prompts to the terminal.') parser.add_argument('--verbose', action='store_true', help='Print the prompts to the terminal.')
# Accelerate/transformers # Accelerate/transformers
parser.add_argument('--cpu', action='store_true', help='Use the CPU to generate text.') parser.add_argument('--cpu', action='store_true', help='Use the CPU to generate text. Warning: Training on CPU is extremely slow.')
parser.add_argument('--auto-devices', action='store_true', help='Automatically split the model across the available GPU(s) and CPU.') parser.add_argument('--auto-devices', action='store_true', help='Automatically split the model across the available GPU(s) and CPU.')
parser.add_argument('--gpu-memory', type=str, nargs="+", help='Maxmimum GPU memory in GiB to be allocated per GPU. Example: --gpu-memory 10 for a single GPU, --gpu-memory 10 5 for two GPUs. You can also set values in MiB like --gpu-memory 3500MiB.') parser.add_argument('--gpu-memory', type=str, nargs="+", help='Maxmimum GPU memory in GiB to be allocated per GPU. Example: --gpu-memory 10 for a single GPU, --gpu-memory 10 5 for two GPUs. You can also set values in MiB like --gpu-memory 3500MiB.')
parser.add_argument('--cpu-memory', type=str, help='Maximum CPU memory in GiB to allocate for offloaded weights. Same as above.') parser.add_argument('--cpu-memory', type=str, help='Maximum CPU memory in GiB to allocate for offloaded weights. Same as above.')
@@ -98,6 +99,8 @@ parser.add_argument('--disk-cache-dir', type=str, default="cache", help='Directo
parser.add_argument('--load-in-8bit', action='store_true', help='Load the model with 8-bit precision.') parser.add_argument('--load-in-8bit', action='store_true', help='Load the model with 8-bit precision.')
parser.add_argument('--bf16', action='store_true', help='Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU.') parser.add_argument('--bf16', action='store_true', help='Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU.')
parser.add_argument('--no-cache', action='store_true', help='Set use_cache to False while generating text. This reduces the VRAM usage a bit at a performance cost.') parser.add_argument('--no-cache', action='store_true', help='Set use_cache to False while generating text. This reduces the VRAM usage a bit at a performance cost.')
parser.add_argument('--xformers', action='store_true', help="Use xformer's memory efficient attention. This should increase your tokens/s.")
parser.add_argument('--sdp-attention', action='store_true', help="Use torch 2.0's sdp attention.")
# llama.cpp # llama.cpp
parser.add_argument('--threads', type=int, default=0, help='Number of threads to use in llama.cpp.') parser.add_argument('--threads', type=int, default=0, help='Number of threads to use in llama.cpp.')

View File

@@ -1,3 +1,4 @@
import random
import re import re
import time import time
import traceback import traceback
@@ -21,7 +22,7 @@ def get_max_prompt_length(tokens):
return max_length return max_length
def encode(prompt, tokens_to_generate=0, add_special_tokens=True): def encode(prompt, tokens_to_generate=0, add_special_tokens=True, add_bos_token=True):
if any((shared.is_RWKV, shared.is_llamacpp)): if any((shared.is_RWKV, shared.is_llamacpp)):
input_ids = shared.tokenizer.encode(str(prompt)) input_ids = shared.tokenizer.encode(str(prompt))
input_ids = np.array(input_ids).reshape(1, len(input_ids)) input_ids = np.array(input_ids).reshape(1, len(input_ids))
@@ -29,6 +30,12 @@ def encode(prompt, tokens_to_generate=0, add_special_tokens=True):
else: else:
input_ids = shared.tokenizer.encode(str(prompt), return_tensors='pt', truncation=True, max_length=get_max_prompt_length(tokens_to_generate), add_special_tokens=add_special_tokens) input_ids = shared.tokenizer.encode(str(prompt), return_tensors='pt', truncation=True, max_length=get_max_prompt_length(tokens_to_generate), add_special_tokens=add_special_tokens)
# This is a hack for making replies more creative.
if not add_bos_token and input_ids[0][0] == shared.tokenizer.bos_token_id:
input_ids = input_ids[:, 1:]
# Llama adds this extra token when the first character is '\n', and this
# compromises the stopping criteria, so we just remove it
if type(shared.tokenizer) is transformers.LlamaTokenizer and input_ids[0][0] == 29871: if type(shared.tokenizer) is transformers.LlamaTokenizer and input_ids[0][0] == 29871:
input_ids = input_ids[:, 1:] input_ids = input_ids[:, 1:]
@@ -62,9 +69,8 @@ def generate_softprompt_input_tensors(input_ids):
# filler_input_ids += shared.model.config.bos_token_id # setting dummy input_ids to bos tokens # filler_input_ids += shared.model.config.bos_token_id # setting dummy input_ids to bos tokens
return inputs_embeds, filler_input_ids return inputs_embeds, filler_input_ids
# Removes empty replies from gpt4chan outputs # Removes empty replies from gpt4chan outputs
def fix_gpt4chan(s): def fix_gpt4chan(s):
for i in range(10): for i in range(10):
s = re.sub("--- [0-9]*\n>>[0-9]*\n---", "---", s) s = re.sub("--- [0-9]*\n>>[0-9]*\n---", "---", s)
@@ -72,9 +78,8 @@ def fix_gpt4chan(s):
s = re.sub("--- [0-9]*\n\n\n---", "---", s) s = re.sub("--- [0-9]*\n\n\n---", "---", s)
return s return s
# Fix the LaTeX equations in galactica # Fix the LaTeX equations in galactica
def fix_galactica(s): def fix_galactica(s):
s = s.replace(r'\[', r'$') s = s.replace(r'\[', r'$')
s = s.replace(r'\]', r'$') s = s.replace(r'\]', r'$')
@@ -101,19 +106,22 @@ def formatted_outputs(reply, model_name):
def set_manual_seed(seed): def set_manual_seed(seed):
if seed != -1: seed = int(seed)
torch.manual_seed(seed) if seed == -1:
if torch.cuda.is_available(): seed = random.randint(1, 2**31)
torch.cuda.manual_seed_all(seed) torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)
return seed
def stop_everything_event(): def stop_everything_event():
shared.stop_everything = True shared.stop_everything = True
def generate_reply(question, generate_state, eos_token=None, stopping_strings=[]): def generate_reply(question, state, eos_token=None, stopping_strings=[]):
clear_torch_cache() clear_torch_cache()
set_manual_seed(generate_state['seed']) seed = set_manual_seed(state['seed'])
shared.stop_everything = False shared.stop_everything = False
generate_params = {} generate_params = {}
t0 = time.time() t0 = time.time()
@@ -128,8 +136,8 @@ def generate_reply(question, generate_state, eos_token=None, stopping_strings=[]
# separately and terminate the function call earlier # separately and terminate the function call earlier
if any((shared.is_RWKV, shared.is_llamacpp)): if any((shared.is_RWKV, shared.is_llamacpp)):
for k in ['temperature', 'top_p', 'top_k', 'repetition_penalty']: for k in ['temperature', 'top_p', 'top_k', 'repetition_penalty']:
generate_params[k] = generate_state[k] generate_params[k] = state[k]
generate_params['token_count'] = generate_state['max_new_tokens'] generate_params['token_count'] = state['max_new_tokens']
try: try:
if shared.args.no_stream: if shared.args.no_stream:
reply = shared.model.generate(context=question, **generate_params) reply = shared.model.generate(context=question, **generate_params)
@@ -155,10 +163,10 @@ def generate_reply(question, generate_state, eos_token=None, stopping_strings=[]
t1 = time.time() t1 = time.time()
original_tokens = len(encode(original_question)[0]) original_tokens = len(encode(original_question)[0])
new_tokens = len(encode(output)[0]) - original_tokens new_tokens = len(encode(output)[0]) - original_tokens
print(f'Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens})') print(f'Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens}, seed {seed})')
return return
input_ids = encode(question, generate_state['max_new_tokens']) input_ids = encode(question, state['max_new_tokens'], add_bos_token=state['add_bos_token'])
original_input_ids = input_ids original_input_ids = input_ids
output = input_ids[0] output = input_ids[0]
@@ -173,15 +181,13 @@ def generate_reply(question, generate_state, eos_token=None, stopping_strings=[]
if not shared.args.flexgen: if not shared.args.flexgen:
for k in ['max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']: for k in ['max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']:
generate_params[k] = generate_state[k] generate_params[k] = state[k]
generate_params['eos_token_id'] = eos_token_ids generate_params['eos_token_id'] = eos_token_ids
generate_params['stopping_criteria'] = stopping_criteria_list generate_params['stopping_criteria'] = stopping_criteria_list
if shared.args.no_stream:
generate_params['min_length'] = 0
else: else:
for k in ['max_new_tokens', 'do_sample', 'temperature']: for k in ['max_new_tokens', 'do_sample', 'temperature']:
generate_params[k] = generate_state[k] generate_params[k] = state[k]
generate_params['stop'] = generate_state['eos_token_ids'][-1] generate_params['stop'] = state['eos_token_ids'][-1]
if not shared.args.no_stream: if not shared.args.no_stream:
generate_params['max_new_tokens'] = 8 generate_params['max_new_tokens'] = 8
@@ -244,7 +250,7 @@ def generate_reply(question, generate_state, eos_token=None, stopping_strings=[]
# Stream the output naively for FlexGen since it doesn't support 'stopping_criteria' # Stream the output naively for FlexGen since it doesn't support 'stopping_criteria'
else: else:
for i in range(generate_state['max_new_tokens'] // 8 + 1): for i in range(state['max_new_tokens'] // 8 + 1):
clear_torch_cache() clear_torch_cache()
with torch.no_grad(): with torch.no_grad():
output = shared.model.generate(**generate_params)[0] output = shared.model.generate(**generate_params)[0]
@@ -276,5 +282,5 @@ def generate_reply(question, generate_state, eos_token=None, stopping_strings=[]
t1 = time.time() t1 = time.time()
original_tokens = len(original_input_ids[0]) original_tokens = len(original_input_ids[0])
new_tokens = len(output) - original_tokens new_tokens = len(output) - original_tokens
print(f'Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens})') print(f'Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens}, seed {seed})')
return return

View File

@@ -238,7 +238,7 @@ def do_train(lora_name: str, micro_batch_size: int, batch_size: int, epochs: int
warmup_steps=100, warmup_steps=100,
num_train_epochs=epochs, num_train_epochs=epochs,
learning_rate=actual_lr, learning_rate=actual_lr,
fp16=True, fp16=False if shared.args.cpu else True,
logging_steps=20, logging_steps=20,
evaluation_strategy="steps" if eval_data is not None else "no", evaluation_strategy="steps" if eval_data is not None else "no",
save_strategy="steps", save_strategy="steps",
@@ -248,7 +248,8 @@ def do_train(lora_name: str, micro_batch_size: int, batch_size: int, epochs: int
save_total_limit=3, save_total_limit=3,
load_best_model_at_end=True if eval_data is not None else False, load_best_model_at_end=True if eval_data is not None else False,
# TODO: Enable multi-device support # TODO: Enable multi-device support
ddp_find_unused_parameters=None ddp_find_unused_parameters=None,
no_cuda=shared.args.cpu
), ),
data_collator=transformers.DataCollatorForLanguageModeling(shared.tokenizer, mlm=False), data_collator=transformers.DataCollatorForLanguageModeling(shared.tokenizer, mlm=False),
callbacks=list([Callbacks()]) callbacks=list([Callbacks()])

View File

@@ -1,10 +1,10 @@
accelerate==0.18.0 accelerate==0.18.0
bitsandbytes==0.37.2
datasets datasets
flexgen==0.1.7 flexgen==0.1.7
gradio==3.24.1 gradio==3.24.1
markdown markdown
numpy numpy
Pillow>=9.5.0
peft==0.2.0 peft==0.2.0
requests requests
rwkv==0.7.3 rwkv==0.7.3
@@ -13,3 +13,6 @@ sentencepiece
pyyaml pyyaml
tqdm tqdm
git+https://github.com/huggingface/transformers git+https://github.com/huggingface/transformers
bitsandbytes==0.37.2; platform_system != "Windows"
llama-cpp-python==0.1.30; platform_system != "Windows"
https://github.com/abetlen/llama-cpp-python/releases/download/v0.1.30/llama_cpp_python-0.1.30-cp310-cp310-win_amd64.whl; platform_system == "Windows"

262
server.py
View File

@@ -2,11 +2,14 @@ import os
os.environ['GRADIO_ANALYTICS_ENABLED'] = 'False' os.environ['GRADIO_ANALYTICS_ENABLED'] = 'False'
import importlib
import io import io
import json import json
import os
import re import re
import sys import sys
import time import time
import traceback
import zipfile import zipfile
from datetime import datetime from datetime import datetime
from pathlib import Path from pathlib import Path
@@ -21,6 +24,7 @@ from modules.LoRA import add_lora_to_model
from modules.models import load_model, load_soft_prompt, unload_model from modules.models import load_model, load_soft_prompt, unload_model
from modules.text_generation import generate_reply, stop_everything_event from modules.text_generation import generate_reply, stop_everything_event
# Loading custom settings # Loading custom settings
settings_file = None settings_file = None
if shared.args.settings is not None and Path(shared.args.settings).exists(): if shared.args.settings is not None and Path(shared.args.settings).exists():
@@ -172,6 +176,34 @@ def create_prompt_menus():
shared.gradio['save_prompt'].click(save_prompt, [shared.gradio['textbox']], [shared.gradio['status']], show_progress=False) shared.gradio['save_prompt'].click(save_prompt, [shared.gradio['textbox']], [shared.gradio['status']], show_progress=False)
def download_model_wrapper(repo_id):
try:
downloader = importlib.import_module("download-model")
model = repo_id
branch = "main"
check = False
yield ("Cleaning up the model/branch names")
model, branch = downloader.sanitize_model_and_branch_names(model, branch)
yield ("Getting the download links from Hugging Face")
links, sha256, is_lora = downloader.get_download_links_from_huggingface(model, branch, text_only=False)
yield ("Getting the output folder")
output_folder = downloader.get_output_folder(model, branch, is_lora)
if check:
yield ("Checking previously downloaded files")
downloader.check_model_files(model, branch, links, sha256, output_folder)
else:
yield (f"Downloading files to {output_folder}")
downloader.download_model_files(model, branch, links, sha256, output_folder, threads=1)
yield ("Done!")
except:
yield traceback.format_exc()
def create_model_menus(): def create_model_menus():
with gr.Row(): with gr.Row():
with gr.Column(): with gr.Column():
@@ -182,14 +214,26 @@ def create_model_menus():
with gr.Row(): with gr.Row():
shared.gradio['lora_menu'] = gr.Dropdown(choices=available_loras, value=shared.lora_name, label='LoRA') shared.gradio['lora_menu'] = gr.Dropdown(choices=available_loras, value=shared.lora_name, label='LoRA')
ui.create_refresh_button(shared.gradio['lora_menu'], lambda: None, lambda: {'choices': get_available_loras()}, 'refresh-button') ui.create_refresh_button(shared.gradio['lora_menu'], lambda: None, lambda: {'choices': get_available_loras()}, 'refresh-button')
with gr.Row():
with gr.Column():
with gr.Row():
with gr.Column():
shared.gradio['custom_model_menu'] = gr.Textbox(label="Download custom model or LoRA",
info="Enter Hugging Face username/model path, e.g: facebook/galactica-125m")
with gr.Column():
shared.gradio['download_button'] = gr.Button("Download")
shared.gradio['download_status'] = gr.Markdown()
with gr.Column():
pass
shared.gradio['model_menu'].change(load_model_wrapper, shared.gradio['model_menu'], shared.gradio['model_menu'], show_progress=True) shared.gradio['model_menu'].change(load_model_wrapper, shared.gradio['model_menu'], shared.gradio['model_menu'], show_progress=True)
shared.gradio['lora_menu'].change(load_lora_wrapper, shared.gradio['lora_menu'], shared.gradio['lora_menu'], show_progress=True) shared.gradio['lora_menu'].change(load_lora_wrapper, shared.gradio['lora_menu'], shared.gradio['lora_menu'], show_progress=True)
shared.gradio['download_button'].click(download_model_wrapper, shared.gradio['custom_model_menu'], shared.gradio['download_status'], show_progress=False)
def create_settings_menus(default_preset): def create_settings_menus(default_preset):
generate_params = load_preset_values(default_preset if not shared.args.flexgen else 'Naive', {}, return_dict=True) generate_params = load_preset_values(default_preset if not shared.args.flexgen else 'Naive', {}, return_dict=True)
for k in ['max_new_tokens', 'seed', 'stop_at_newline', 'chat_prompt_size', 'chat_generation_attempts']: for k in ['max_new_tokens', 'seed', 'stop_at_newline', 'chat_prompt_size', 'chat_generation_attempts', 'add_bos_token']:
generate_params[k] = shared.settings[k] generate_params[k] = shared.settings[k]
shared.gradio['generate_state'] = gr.State(generate_params) shared.gradio['generate_state'] = gr.State(generate_params)
@@ -204,18 +248,18 @@ def create_settings_menus(default_preset):
with gr.Row(): with gr.Row():
with gr.Column(): with gr.Column():
with gr.Box(): with gr.Box():
gr.Markdown('Custom generation parameters ([reference](https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.GenerationConfig))') gr.Markdown('Custom generation parameters ([click here to view technical documentation](https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.GenerationConfig))')
with gr.Row(): with gr.Row():
with gr.Column(): with gr.Column():
shared.gradio['temperature'] = gr.Slider(0.01, 1.99, value=generate_params['temperature'], step=0.01, label='temperature') shared.gradio['temperature'] = gr.Slider(0.01, 1.99, value=generate_params['temperature'], step=0.01, label='temperature', info='Primary factor to control randomness of outputs. 0 = deterministic (only the most likely token is used). Higher value = more randomness.')
shared.gradio['top_p'] = gr.Slider(0.0, 1.0, value=generate_params['top_p'], step=0.01, label='top_p') shared.gradio['top_p'] = gr.Slider(0.0, 1.0, value=generate_params['top_p'], step=0.01, label='top_p', info='If not set to 1, select tokens with probabilities adding up to less than this number. Higher value = higher range of possible random results.')
shared.gradio['top_k'] = gr.Slider(0, 200, value=generate_params['top_k'], step=1, label='top_k') shared.gradio['top_k'] = gr.Slider(0, 200, value=generate_params['top_k'], step=1, label='top_k', info='Similar to top_p, but select instead only the top_k most likely tokens. Higher value = higher range of possible random results.')
shared.gradio['typical_p'] = gr.Slider(0.0, 1.0, value=generate_params['typical_p'], step=0.01, label='typical_p') shared.gradio['typical_p'] = gr.Slider(0.0, 1.0, value=generate_params['typical_p'], step=0.01, label='typical_p', info='If not set to 1, select only tokens that are at least this much more likely to appear than random tokens, given the prior text.')
with gr.Column(): with gr.Column():
shared.gradio['repetition_penalty'] = gr.Slider(1.0, 1.5, value=generate_params['repetition_penalty'], step=0.01, label='repetition_penalty') shared.gradio['repetition_penalty'] = gr.Slider(1.0, 1.5, value=generate_params['repetition_penalty'], step=0.01, label='repetition_penalty', info='Exponential penalty factor for repeating prior tokens. 1 means no penalty, higher value = less repetition, lower value = more repetition.')
shared.gradio['encoder_repetition_penalty'] = gr.Slider(0.8, 1.5, value=generate_params['encoder_repetition_penalty'], step=0.01, label='encoder_repetition_penalty') shared.gradio['encoder_repetition_penalty'] = gr.Slider(0.8, 1.5, value=generate_params['encoder_repetition_penalty'], step=0.01, label='encoder_repetition_penalty', info='Also known as the "Hallucinations filter". Used to penalize tokens that are *not* in the prior text. Higher value = more likely to stay in context, lower value = more likely to diverge.')
shared.gradio['no_repeat_ngram_size'] = gr.Slider(0, 20, step=1, value=generate_params['no_repeat_ngram_size'], label='no_repeat_ngram_size') shared.gradio['no_repeat_ngram_size'] = gr.Slider(0, 20, step=1, value=generate_params['no_repeat_ngram_size'], label='no_repeat_ngram_size', info='If not set to 0, specifies the length of token sets that are completely blocked from repeating at all. Higher values = blocks larger phrases, lower values = blocks words or letters from repeating. Only 0 or high values are a good idea in most cases.')
shared.gradio['min_length'] = gr.Slider(0, 2000, step=1, value=generate_params['min_length'] if shared.args.no_stream else 0, label='min_length', interactive=shared.args.no_stream) shared.gradio['min_length'] = gr.Slider(0, 2000, step=1, value=generate_params['min_length'], label='min_length', info='Minimum generation length in tokens.')
shared.gradio['do_sample'] = gr.Checkbox(value=generate_params['do_sample'], label='do_sample') shared.gradio['do_sample'] = gr.Checkbox(value=generate_params['do_sample'], label='do_sample')
with gr.Column(): with gr.Column():
with gr.Box(): with gr.Box():
@@ -229,6 +273,7 @@ def create_settings_menus(default_preset):
with gr.Column(): with gr.Column():
shared.gradio['length_penalty'] = gr.Slider(-5, 5, value=generate_params['length_penalty'], label='length_penalty') shared.gradio['length_penalty'] = gr.Slider(-5, 5, value=generate_params['length_penalty'], label='length_penalty')
shared.gradio['early_stopping'] = gr.Checkbox(value=generate_params['early_stopping'], label='early_stopping') shared.gradio['early_stopping'] = gr.Checkbox(value=generate_params['early_stopping'], label='early_stopping')
shared.gradio['add_bos_token'] = gr.Checkbox(value=shared.settings['add_bos_token'], label='Add the bos_token to the beginning of prompts', info='Disabling this can make the replies more creative.')
with gr.Accordion('Soft prompt', open=False): with gr.Accordion('Soft prompt', open=False):
with gr.Row(): with gr.Row():
@@ -312,6 +357,20 @@ else:
title = 'Text generation web UI' title = 'Text generation web UI'
def list_interface_input_elements(chat=False):
elements = ['max_new_tokens', 'seed', 'temperature', 'top_p', 'top_k', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'no_repeat_ngram_size', 'min_length', 'do_sample', 'penalty_alpha', 'num_beams', 'length_penalty', 'early_stopping', 'add_bos_token']
if chat:
elements += ['name1', 'name2', 'greeting', 'context', 'end_of_turn', 'chat_prompt_size', 'chat_generation_attempts', 'stop_at_newline', 'mode']
return elements
def gather_interface_values(*args):
output = {}
for i, element in enumerate(shared.input_elements):
output[element] = args[i]
return output
def create_interface(): def create_interface():
gen_events = [] gen_events = []
if shared.args.extensions is not None and len(shared.args.extensions) > 0: if shared.args.extensions is not None and len(shared.args.extensions) > 0:
@@ -319,7 +378,11 @@ def create_interface():
with gr.Blocks(css=ui.css if not shared.is_chat() else ui.css + ui.chat_css, analytics_enabled=False, title=title) as shared.gradio['interface']: with gr.Blocks(css=ui.css if not shared.is_chat() else ui.css + ui.chat_css, analytics_enabled=False, title=title) as shared.gradio['interface']:
if shared.is_chat(): if shared.is_chat():
shared.input_elements = list_interface_input_elements(chat=True)
shared.gradio['interface_state'] = gr.State({k: None for k in shared.input_elements})
shared.gradio['Chat input'] = gr.State() shared.gradio['Chat input'] = gr.State()
with gr.Tab("Text generation", elem_id="main"): with gr.Tab("Text generation", elem_id="main"):
shared.gradio['display'] = gr.HTML(value=chat_html_wrapper(shared.history['visible'], shared.settings['name1'], shared.settings['name2'], 'cai-chat')) shared.gradio['display'] = gr.HTML(value=chat_html_wrapper(shared.history['visible'], shared.settings['name1'], shared.settings['name2'], 'cai-chat'))
shared.gradio['textbox'] = gr.Textbox(label='Input') shared.gradio['textbox'] = gr.Textbox(label='Input')
@@ -327,8 +390,9 @@ def create_interface():
shared.gradio['Generate'] = gr.Button('Generate', elem_id='Generate') shared.gradio['Generate'] = gr.Button('Generate', elem_id='Generate')
shared.gradio['Stop'] = gr.Button('Stop', elem_id="stop") shared.gradio['Stop'] = gr.Button('Stop', elem_id="stop")
with gr.Row(): with gr.Row():
shared.gradio['Impersonate'] = gr.Button('Impersonate')
shared.gradio['Regenerate'] = gr.Button('Regenerate') shared.gradio['Regenerate'] = gr.Button('Regenerate')
shared.gradio['Continue'] = gr.Button('Continue')
shared.gradio['Impersonate'] = gr.Button('Impersonate')
with gr.Row(): with gr.Row():
shared.gradio['Copy last reply'] = gr.Button('Copy last reply') shared.gradio['Copy last reply'] = gr.Button('Copy last reply')
shared.gradio['Replace last reply'] = gr.Button('Replace last reply') shared.gradio['Replace last reply'] = gr.Button('Replace last reply')
@@ -338,8 +402,8 @@ def create_interface():
shared.gradio['Clear history-confirm'] = gr.Button('Confirm', variant="stop", visible=False) shared.gradio['Clear history-confirm'] = gr.Button('Confirm', variant="stop", visible=False)
shared.gradio['Clear history-cancel'] = gr.Button('Cancel', visible=False) shared.gradio['Clear history-cancel'] = gr.Button('Cancel', visible=False)
shared.gradio["Chat mode"] = gr.Radio(choices=["cai-chat", "chat", "instruct"], value="cai-chat", label="Mode") shared.gradio["mode"] = gr.Radio(choices=["cai-chat", "chat", "instruct"], value="cai-chat", label="Mode")
shared.gradio["Instruction templates"] = gr.Dropdown(choices=get_available_instruction_templates(), label="Instruction template", value="None", visible=False) shared.gradio["Instruction templates"] = gr.Dropdown(choices=get_available_instruction_templates(), label="Instruction template", value="None", visible=False, info="Change this according to the model/LoRA that you are using.")
with gr.Tab("Character", elem_id="chat-settings"): with gr.Tab("Character", elem_id="chat-settings"):
with gr.Row(): with gr.Row():
@@ -386,66 +450,92 @@ def create_interface():
with gr.Row(): with gr.Row():
with gr.Column(): with gr.Column():
shared.gradio['max_new_tokens'] = gr.Slider(minimum=shared.settings['max_new_tokens_min'], maximum=shared.settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=shared.settings['max_new_tokens']) shared.gradio['max_new_tokens'] = gr.Slider(minimum=shared.settings['max_new_tokens_min'], maximum=shared.settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=shared.settings['max_new_tokens'])
shared.gradio['chat_prompt_size_slider'] = gr.Slider(minimum=shared.settings['chat_prompt_size_min'], maximum=shared.settings['chat_prompt_size_max'], step=1, label='Maximum prompt size in tokens', value=shared.settings['chat_prompt_size']) shared.gradio['chat_prompt_size'] = gr.Slider(minimum=shared.settings['chat_prompt_size_min'], maximum=shared.settings['chat_prompt_size_max'], step=1, label='Maximum prompt size in tokens', value=shared.settings['chat_prompt_size'])
with gr.Column(): with gr.Column():
shared.gradio['chat_generation_attempts'] = gr.Slider(minimum=shared.settings['chat_generation_attempts_min'], maximum=shared.settings['chat_generation_attempts_max'], value=shared.settings['chat_generation_attempts'], step=1, label='Generation attempts (for longer replies)') shared.gradio['chat_generation_attempts'] = gr.Slider(minimum=shared.settings['chat_generation_attempts_min'], maximum=shared.settings['chat_generation_attempts_max'], value=shared.settings['chat_generation_attempts'], step=1, label='Generation attempts (for longer replies)')
shared.gradio['stop_at_newline'] = gr.Checkbox(value=shared.settings['stop_at_newline'], label='Stop generating at new line character?') shared.gradio['stop_at_newline'] = gr.Checkbox(value=shared.settings['stop_at_newline'], label='Stop generating at new line character?')
create_settings_menus(default_preset) create_settings_menus(default_preset)
shared.input_params = [shared.gradio[k] for k in ['Chat input', 'generate_state', 'name1', 'name2', 'context', 'Chat mode', 'end_of_turn']] shared.input_params = [shared.gradio[k] for k in ['Chat input', 'interface_state']]
def set_chat_input(textbox):
return textbox, ""
gen_events.append(shared.gradio['Generate'].click(set_chat_input, shared.gradio['textbox'], [shared.gradio['Chat input'], shared.gradio['textbox']], show_progress=False))
gen_events.append(shared.gradio['Generate'].click(chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream))
gen_events.append(shared.gradio['textbox'].submit(set_chat_input, shared.gradio['textbox'], [shared.gradio['Chat input'], shared.gradio['textbox']], show_progress=False))
gen_events.append(shared.gradio['textbox'].submit(chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream))
gen_events.append(shared.gradio['Regenerate'].click(chat.regenerate_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream))
gen_events.append(shared.gradio['Impersonate'].click(chat.impersonate_wrapper, shared.input_params, shared.gradio['textbox'], show_progress=shared.args.no_stream))
shared.gradio['Stop'].click(stop_everything_event, [], [], queue=False, cancels=gen_events if shared.args.no_stream else None)
shared.gradio['Copy last reply'].click(chat.send_last_reply_to_input, [], shared.gradio['textbox'], show_progress=shared.args.no_stream)
shared.gradio['Replace last reply'].click(chat.replace_last_reply, [shared.gradio[k] for k in ['textbox', 'name1', 'name2', 'Chat mode']], shared.gradio['display'], show_progress=shared.args.no_stream)
# Clear history with confirmation
clear_arr = [shared.gradio[k] for k in ['Clear history-confirm', 'Clear history', 'Clear history-cancel']] clear_arr = [shared.gradio[k] for k in ['Clear history-confirm', 'Clear history', 'Clear history-cancel']]
reload_inputs = [shared.gradio[k] for k in ['name1', 'name2', 'mode']]
gen_events.append(shared.gradio['Generate'].click(
gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
lambda x: (x, ''), shared.gradio['textbox'], [shared.gradio['Chat input'], shared.gradio['textbox']], show_progress=False).then(
chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream).then(
chat.save_history, shared.gradio['mode'], None, show_progress=False)
)
gen_events.append(shared.gradio['textbox'].submit(
gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
lambda x: (x, ''), shared.gradio['textbox'], [shared.gradio['Chat input'], shared.gradio['textbox']], show_progress=False).then(
chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream).then(
chat.save_history, shared.gradio['mode'], None, show_progress=False)
)
gen_events.append(shared.gradio['Regenerate'].click(
gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
chat.regenerate_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream).then(
chat.save_history, shared.gradio['mode'], None, show_progress=False)
)
gen_events.append(shared.gradio['Continue'].click(
gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
chat.continue_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream).then(
chat.save_history, shared.gradio['mode'], None, show_progress=False)
)
gen_events.append(shared.gradio['Impersonate'].click(
gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
chat.impersonate_wrapper, shared.input_params, shared.gradio['textbox'], show_progress=shared.args.no_stream)
)
shared.gradio['Replace last reply'].click(
chat.replace_last_reply, [shared.gradio[k] for k in ['textbox', 'name1', 'name2', 'mode']], shared.gradio['display'], show_progress=shared.args.no_stream).then(
lambda x: '', shared.gradio['textbox'], shared.gradio['textbox'], show_progress=False).then(
chat.save_history, shared.gradio['mode'], None, show_progress=False)
shared.gradio['Clear history-confirm'].click(
lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, clear_arr).then(
chat.clear_chat_log, [shared.gradio[k] for k in ['name1', 'name2', 'greeting', 'mode']], shared.gradio['display']).then(
chat.save_history, shared.gradio['mode'], None, show_progress=False)
shared.gradio['Stop'].click(
stop_everything_event, None, None, queue=False, cancels=gen_events if shared.args.no_stream else None).then(
chat.redraw_html, reload_inputs, shared.gradio['display'])
shared.gradio['mode'].change(
lambda x: gr.update(visible=x == 'instruct'), shared.gradio['mode'], shared.gradio['Instruction templates']).then(
lambda x: gr.update(interactive=x != 'instruct'), shared.gradio['mode'], shared.gradio['character_menu']).then(
chat.redraw_html, reload_inputs, shared.gradio['display'])
shared.gradio['Instruction templates'].change(
lambda character, name1, name2, mode: chat.load_character(character, name1, name2, mode), [shared.gradio[k] for k in ['Instruction templates', 'name1', 'name2', 'mode']], [shared.gradio[k] for k in ['name1', 'name2', 'character_picture', 'greeting', 'context', 'end_of_turn', 'display']]).then(
chat.redraw_html, reload_inputs, shared.gradio['display'])
shared.gradio['upload_chat_history'].upload(
chat.load_history, [shared.gradio[k] for k in ['upload_chat_history', 'name1', 'name2']], None).then(
chat.redraw_html, reload_inputs, shared.gradio['display'])
shared.gradio['Copy last reply'].click(chat.send_last_reply_to_input, None, shared.gradio['textbox'], show_progress=shared.args.no_stream)
shared.gradio['Clear history'].click(lambda: [gr.update(visible=True), gr.update(visible=False), gr.update(visible=True)], None, clear_arr) shared.gradio['Clear history'].click(lambda: [gr.update(visible=True), gr.update(visible=False), gr.update(visible=True)], None, clear_arr)
shared.gradio['Clear history-confirm'].click(lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, clear_arr)
shared.gradio['Clear history-confirm'].click(chat.clear_chat_log, [shared.gradio[k] for k in ['name1', 'name2', 'greeting', 'Chat mode']], shared.gradio['display'])
shared.gradio['Clear history-cancel'].click(lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, clear_arr) shared.gradio['Clear history-cancel'].click(lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, clear_arr)
shared.gradio['Chat mode'].change(lambda x: gr.update(visible=x == 'instruct'), shared.gradio['Chat mode'], shared.gradio['Instruction templates']) shared.gradio['Remove last'].click(chat.remove_last_message, [shared.gradio[k] for k in ['name1', 'name2', 'mode']], [shared.gradio['display'], shared.gradio['textbox']], show_progress=False)
shared.gradio['download_button'].click(lambda x: chat.save_history(x, timestamp=True), shared.gradio['mode'], shared.gradio['download'])
shared.gradio['Remove last'].click(chat.remove_last_message, [shared.gradio[k] for k in ['name1', 'name2', 'Chat mode']], [shared.gradio['display'], shared.gradio['textbox']], show_progress=False)
shared.gradio['download_button'].click(chat.save_history, inputs=[], outputs=[shared.gradio['download']])
shared.gradio['Upload character'].click(chat.upload_character, [shared.gradio['upload_json'], shared.gradio['upload_img_bot']], [shared.gradio['character_menu']]) shared.gradio['Upload character'].click(chat.upload_character, [shared.gradio['upload_json'], shared.gradio['upload_img_bot']], [shared.gradio['character_menu']])
shared.gradio['character_menu'].change(chat.load_character, [shared.gradio[k] for k in ['character_menu', 'name1', 'name2', 'mode']], [shared.gradio[k] for k in ['name1', 'name2', 'character_picture', 'greeting', 'context', 'end_of_turn', 'display']])
# Clearing stuff and saving the history
for i in ['Generate', 'Regenerate', 'Replace last reply']:
shared.gradio[i].click(lambda x: '', shared.gradio['textbox'], shared.gradio['textbox'], show_progress=False)
shared.gradio[i].click(lambda: chat.save_history(timestamp=False), [], [], show_progress=False)
shared.gradio['Clear history-confirm'].click(lambda: chat.save_history(timestamp=False), [], [], show_progress=False)
shared.gradio['textbox'].submit(lambda x: '', shared.gradio['textbox'], shared.gradio['textbox'], show_progress=False)
shared.gradio['textbox'].submit(lambda: chat.save_history(timestamp=False), [], [], show_progress=False)
shared.gradio['character_menu'].change(chat.load_character, [shared.gradio[k] for k in ['character_menu', 'name1', 'name2', 'Chat mode']], [shared.gradio[k] for k in ['name1', 'name2', 'character_picture', 'greeting', 'context', 'end_of_turn', 'display']])
shared.gradio['Instruction templates'].change(lambda character, name1, name2, mode: chat.load_character(character, name1, name2, mode), [shared.gradio[k] for k in ['Instruction templates', 'name1', 'name2', 'Chat mode']], [shared.gradio[k] for k in ['name1', 'name2', 'character_picture', 'greeting', 'context', 'end_of_turn', 'display']])
shared.gradio['upload_chat_history'].upload(chat.load_history, [shared.gradio[k] for k in ['upload_chat_history', 'name1', 'name2']], [])
shared.gradio['upload_img_tavern'].upload(chat.upload_tavern_character, [shared.gradio['upload_img_tavern'], shared.gradio['name1'], shared.gradio['name2']], [shared.gradio['character_menu']]) shared.gradio['upload_img_tavern'].upload(chat.upload_tavern_character, [shared.gradio['upload_img_tavern'], shared.gradio['name1'], shared.gradio['name2']], [shared.gradio['character_menu']])
shared.gradio['your_picture'].change(chat.upload_your_profile_picture, [shared.gradio[k] for k in ['your_picture', 'name1', 'name2', 'Chat mode']], shared.gradio['display']) shared.gradio['your_picture'].change(chat.upload_your_profile_picture, [shared.gradio[k] for k in ['your_picture', 'name1', 'name2', 'mode']], shared.gradio['display'])
reload_inputs = [shared.gradio[k] for k in ['name1', 'name2', 'Chat mode']]
shared.gradio['upload_chat_history'].upload(chat.redraw_html, reload_inputs, [shared.gradio['display']])
shared.gradio['Stop'].click(chat.redraw_html, reload_inputs, [shared.gradio['display']])
shared.gradio['Instruction templates'].change(chat.redraw_html, reload_inputs, [shared.gradio['display']])
shared.gradio['Chat mode'].change(chat.redraw_html, reload_inputs, [shared.gradio['display']])
shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js+ui.chat_js}}}") shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js+ui.chat_js}}}")
shared.gradio['interface'].load(lambda: chat.load_default_history(shared.settings['name1'], shared.settings['name2']), None, None) shared.gradio['interface'].load(chat.load_default_history, [shared.gradio[k] for k in ['name1', 'name2']], None)
shared.gradio['interface'].load(chat.redraw_html, reload_inputs, [shared.gradio['display']], show_progress=True) shared.gradio['interface'].load(chat.redraw_html, reload_inputs, shared.gradio['display'], show_progress=True)
elif shared.args.notebook: elif shared.args.notebook:
shared.input_elements = list_interface_input_elements(chat=False)
shared.gradio['interface_state'] = gr.State({k: None for k in shared.input_elements})
with gr.Tab("Text generation", elem_id="main"): with gr.Tab("Text generation", elem_id="main"):
with gr.Row(): with gr.Row():
with gr.Column(scale=4): with gr.Column(scale=4):
@@ -475,12 +565,23 @@ def create_interface():
shared.input_params = [shared.gradio[k] for k in ['textbox', 'generate_state']] shared.input_params = [shared.gradio[k] for k in ['textbox', 'generate_state']]
output_params = [shared.gradio[k] for k in ['textbox', 'markdown', 'html']] output_params = [shared.gradio[k] for k in ['textbox', 'markdown', 'html']]
gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream))
gen_events.append(shared.gradio['textbox'].submit(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream)) gen_events.append(shared.gradio['Generate'].click(
shared.gradio['Stop'].click(stop_everything_event, [], [], queue=False, cancels=gen_events if shared.args.no_stream else None) gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream)
)
gen_events.append(shared.gradio['textbox'].submit(
gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream)
)
shared.gradio['Stop'].click(stop_everything_event, None, None, queue=False, cancels=gen_events if shared.args.no_stream else None)
shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js}}}") shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js}}}")
else: else:
shared.input_elements = list_interface_input_elements(chat=False)
shared.gradio['interface_state'] = gr.State({k: None for k in shared.input_elements})
with gr.Tab("Text generation", elem_id="main"): with gr.Tab("Text generation", elem_id="main"):
with gr.Row(): with gr.Row():
with gr.Column(): with gr.Column():
@@ -508,10 +609,23 @@ def create_interface():
shared.input_params = [shared.gradio[k] for k in ['textbox', 'generate_state']] shared.input_params = [shared.gradio[k] for k in ['textbox', 'generate_state']]
output_params = [shared.gradio[k] for k in ['output_textbox', 'markdown', 'html']] output_params = [shared.gradio[k] for k in ['output_textbox', 'markdown', 'html']]
gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream))
gen_events.append(shared.gradio['textbox'].submit(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream)) gen_events.append(shared.gradio['Generate'].click(
gen_events.append(shared.gradio['Continue'].click(generate_reply, [shared.gradio['output_textbox']] + shared.input_params[1:], output_params, show_progress=shared.args.no_stream)) gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
shared.gradio['Stop'].click(stop_everything_event, [], [], queue=False, cancels=gen_events if shared.args.no_stream else None) generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream)
)
gen_events.append(shared.gradio['textbox'].submit(
gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream)
)
gen_events.append(shared.gradio['Continue'].click(
gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
generate_reply, [shared.gradio['output_textbox']] + shared.input_params[1:], output_params, show_progress=shared.args.no_stream)
)
shared.gradio['Stop'].click(stop_everything_event, None, None, queue=False, cancels=gen_events if shared.args.no_stream else None)
shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js}}}") shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js}}}")
with gr.Tab("Model", elem_id="model-tab"): with gr.Tab("Model", elem_id="model-tab"):
@@ -537,24 +651,14 @@ def create_interface():
shared.gradio['bool_menu'] = gr.CheckboxGroup(choices=bool_list, value=bool_active, label="Boolean command-line flags") shared.gradio['bool_menu'] = gr.CheckboxGroup(choices=bool_list, value=bool_active, label="Boolean command-line flags")
shared.gradio['reset_interface'] = gr.Button("Apply and restart the interface") shared.gradio['reset_interface'] = gr.Button("Apply and restart the interface")
shared.gradio['reset_interface'].click(set_interface_arguments, [shared.gradio[k] for k in ['interface_modes_menu', 'extensions_menu', 'bool_menu']], None) # Reset interface event
shared.gradio['reset_interface'].click(lambda: None, None, None, _js='() => {document.body.innerHTML=\'<h1 style="font-family:monospace;margin-top:20%;color:lightgray;text-align:center;">Reloading...</h1>\'; setTimeout(function(){location.reload()},2500); return []}') shared.gradio['reset_interface'].click(
set_interface_arguments, [shared.gradio[k] for k in ['interface_modes_menu', 'extensions_menu', 'bool_menu']], None).then(
lambda: None, None, None, _js='() => {document.body.innerHTML=\'<h1 style="font-family:monospace;margin-top:20%;color:lightgray;text-align:center;">Reloading...</h1>\'; setTimeout(function(){location.reload()},2500); return []}')
if shared.args.extensions is not None: if shared.args.extensions is not None:
extensions_module.create_extensions_block() extensions_module.create_extensions_block()
def change_dict_value(d, key, value):
d[key] = value
return d
for k in ['do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'max_new_tokens', 'seed', 'stop_at_newline', 'chat_prompt_size_slider', 'chat_generation_attempts']:
if k not in shared.gradio:
continue
if type(shared.gradio[k]) in [gr.Checkbox, gr.Number]:
shared.gradio[k].change(lambda state, value, copy=k: change_dict_value(state, copy, value), inputs=[shared.gradio['generate_state'], shared.gradio[k]], outputs=shared.gradio['generate_state'])
else:
shared.gradio[k].release(lambda state, value, copy=k: change_dict_value(state, copy, value), inputs=[shared.gradio['generate_state'], shared.gradio[k]], outputs=shared.gradio['generate_state'])
if not shared.is_chat(): if not shared.is_chat():
api.create_apis() api.create_apis()

View File

@@ -7,7 +7,9 @@
"name2": "Assistant", "name2": "Assistant",
"context": "This is a conversation with your Assistant. The Assistant is very helpful and is eager to chat with you and answer your questions.", "context": "This is a conversation with your Assistant. The Assistant is very helpful and is eager to chat with you and answer your questions.",
"greeting": "Hello there!", "greeting": "Hello there!",
"end_of_turn": "",
"stop_at_newline": false, "stop_at_newline": false,
"add_bos_token": true,
"chat_prompt_size": 2048, "chat_prompt_size": 2048,
"chat_prompt_size_min": 0, "chat_prompt_size_min": 0,
"chat_prompt_size_max": 2048, "chat_prompt_size_max": 2048,
@@ -19,7 +21,8 @@
"gallery" "gallery"
], ],
"presets": { "presets": {
"default": "NovelAI-Sphinx Moth", "default": "Default",
".*(alpaca|llama)": "LLaMA-Precise",
".*pygmalion": "NovelAI-Storywriter", ".*pygmalion": "NovelAI-Storywriter",
".*RWKV": "Naive" ".*RWKV": "Naive"
}, },