Compare commits
52 Commits
concatenat
...
dependabot
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
39fa6e57cc | ||
|
|
5234071c04 | ||
|
|
09d8119e3c | ||
|
|
0caf718a21 | ||
|
|
85a7954823 | ||
|
|
d37b4f76b1 | ||
|
|
bd04ff27ad | ||
|
|
f035b01823 | ||
|
|
b7ca89ba3f | ||
|
|
52339e9b20 | ||
|
|
4961f43702 | ||
|
|
617530296e | ||
|
|
0f1627eff1 | ||
|
|
d679c4be13 | ||
|
|
45244ed125 | ||
|
|
7e70741a4e | ||
|
|
11b23db8d4 | ||
|
|
2c14df81a8 | ||
|
|
c6e9ba20a4 | ||
|
|
843f672227 | ||
|
|
769aa900ea | ||
|
|
32d078487e | ||
|
|
30befe492a | ||
|
|
1911504f82 | ||
|
|
8178fde2cb | ||
|
|
dba2000d2b | ||
|
|
65552d2157 | ||
|
|
8c6155251a | ||
|
|
992663fa20 | ||
|
|
625d81f495 | ||
|
|
57f768eaad | ||
|
|
a3085dba07 | ||
|
|
120f5662cf | ||
|
|
b27d757fd1 | ||
|
|
d29f4624e9 | ||
|
|
170e0c05c4 | ||
|
|
34ec02d41d | ||
|
|
f91d3a3ff4 | ||
|
|
ebdf4c8c12 | ||
|
|
7436dd5b4a | ||
|
|
bce1b7fbb2 | ||
|
|
f7860ce192 | ||
|
|
ece8ed2c84 | ||
|
|
cc693a7546 | ||
|
|
2fde50a800 | ||
|
|
acc235aced | ||
|
|
df561fd896 | ||
|
|
d272ac46dd | ||
|
|
cb169d0834 | ||
|
|
2f16d0afca | ||
|
|
a6a00cb82f | ||
|
|
c97c270040 |
@@ -26,7 +26,7 @@ LABEL maintainer="Your Name <your.email@example.com>"
|
||||
LABEL description="Docker image for GPTQ-for-LLaMa and Text Generation WebUI"
|
||||
|
||||
RUN apt-get update && \
|
||||
apt-get install --no-install-recommends -y git python3 python3-pip && \
|
||||
apt-get install --no-install-recommends -y git python3 python3-pip make g++ && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
|
||||
RUN --mount=type=cache,target=/root/.cache/pip pip3 install virtualenv
|
||||
|
||||
17
README.md
17
README.md
@@ -1,11 +1,9 @@
|
||||
# Text generation web UI
|
||||
|
||||
A gradio web UI for running Large Language Models like LLaMA, llama.cpp, GPT-J, OPT, and GALACTICA.
|
||||
A gradio web UI for running Large Language Models like LLaMA, llama.cpp, GPT-J, Pythia, OPT, and GALACTICA.
|
||||
|
||||
Its goal is to become the [AUTOMATIC1111/stable-diffusion-webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui) of text generation.
|
||||
|
||||
[[Try it on Google Colab]](https://colab.research.google.com/github/oobabooga/AI-Notebooks/blob/main/Colab-TextGen-GPU.ipynb)
|
||||
|
||||
| |  |
|
||||
|:---:|:---:|
|
||||
| |  |
|
||||
@@ -34,7 +32,6 @@ Its goal is to become the [AUTOMATIC1111/stable-diffusion-webui](https://github.
|
||||
* [LoRA (loading and training)](https://github.com/oobabooga/text-generation-webui/wiki/Using-LoRAs)
|
||||
* Softprompts
|
||||
* [Extensions](https://github.com/oobabooga/text-generation-webui/wiki/Extensions)
|
||||
* [Google Colab](https://github.com/oobabooga/text-generation-webui/wiki/Running-on-Colab)
|
||||
|
||||
## Installation
|
||||
|
||||
@@ -73,9 +70,15 @@ On Linux or WSL, it can be automatically installed with these two commands:
|
||||
curl -sL "https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh" > "Miniconda3.sh"
|
||||
bash Miniconda3.sh
|
||||
```
|
||||
|
||||
Source: https://educe-ubc.github.io/conda.html
|
||||
|
||||
#### 0.1 (Ubuntu/WSL) Install build tools
|
||||
|
||||
```
|
||||
sudo apt install build-essential
|
||||
```
|
||||
|
||||
|
||||
#### 1. Create a new conda environment
|
||||
|
||||
```
|
||||
@@ -209,7 +212,7 @@ Optionally, you can use the following command-line flags:
|
||||
|
||||
| Flag | Description |
|
||||
|---------------------------------------------|-------------|
|
||||
| `--cpu` | Use the CPU to generate text. |
|
||||
| `--cpu` | Use the CPU to generate text. Warning: Training on CPU is extremely slow.|
|
||||
| `--auto-devices` | Automatically split the model across the available GPU(s) and CPU. |
|
||||
| `--gpu-memory GPU_MEMORY [GPU_MEMORY ...]` | Maxmimum GPU memory in GiB to be allocated per GPU. Example: `--gpu-memory 10` for a single GPU, `--gpu-memory 10 5` for two GPUs. You can also set values in MiB like `--gpu-memory 3500MiB`. |
|
||||
| `--cpu-memory CPU_MEMORY` | Maximum CPU memory in GiB to allocate for offloaded weights. Same as above.|
|
||||
@@ -218,6 +221,8 @@ Optionally, you can use the following command-line flags:
|
||||
| `--load-in-8bit` | Load the model with 8-bit precision.|
|
||||
| `--bf16` | Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU. |
|
||||
| `--no-cache` | Set `use_cache` to False while generating text. This reduces the VRAM usage a bit with a performance cost. |
|
||||
| `--xformers` | Use xformer's memory efficient attention. This should increase your tokens/s. |
|
||||
| `--sdp-attention` | Use torch 2.0's sdp attention. |
|
||||
|
||||
#### llama.cpp
|
||||
|
||||
|
||||
@@ -22,10 +22,10 @@ server = "127.0.0.1"
|
||||
params = {
|
||||
'max_new_tokens': 200,
|
||||
'do_sample': True,
|
||||
'temperature': 0.5,
|
||||
'top_p': 0.9,
|
||||
'temperature': 0.72,
|
||||
'top_p': 0.73,
|
||||
'typical_p': 1,
|
||||
'repetition_penalty': 1.05,
|
||||
'repetition_penalty': 1.1,
|
||||
'encoder_repetition_penalty': 1.0,
|
||||
'top_k': 0,
|
||||
'min_length': 0,
|
||||
|
||||
@@ -25,9 +25,7 @@
|
||||
.message-body {}
|
||||
|
||||
.message-body p {
|
||||
margin-bottom: 0 !important;
|
||||
font-size: 15px !important;
|
||||
line-height: 1.428571429 !important;
|
||||
}
|
||||
|
||||
.message-body li {
|
||||
@@ -51,15 +49,16 @@
|
||||
padding: 15px;
|
||||
border-radius: 20px;
|
||||
background-color: #0000000f;
|
||||
margin-bottom: 17.5px;
|
||||
margin-top: 9px !important;
|
||||
margin-bottom: 18px !important;
|
||||
}
|
||||
|
||||
.gradio-container .chat .user-message {
|
||||
padding: 15px;
|
||||
border-radius: 20px;
|
||||
margin-bottom: 17.5px !important;
|
||||
margin-bottom: 9px !important;
|
||||
}
|
||||
|
||||
.dark .chat .assistant-message {
|
||||
background-color: #ffffff21;
|
||||
background-color: #374151;
|
||||
}
|
||||
10
css/main.css
10
css/main.css
@@ -67,3 +67,13 @@ span.math.inline {
|
||||
div.svelte-15lo0d8 > *, div.svelte-15lo0d8 > .form > * {
|
||||
flex-wrap: nowrap;
|
||||
}
|
||||
|
||||
.header_bar {
|
||||
background-color: #f7f7f7;
|
||||
margin-bottom: 40px;
|
||||
}
|
||||
|
||||
.dark .header_bar {
|
||||
border: none !important;
|
||||
background-color: #8080802b;
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
document.getElementById("main").parentNode.childNodes[0].style = "border: none; background-color: #8080802b; margin-bottom: 40px";
|
||||
document.getElementById("main").parentNode.childNodes[0].classList.add("header_bar");
|
||||
document.getElementById("main").parentNode.style = "padding: 0; margin: 0";
|
||||
document.getElementById("main").parentNode.parentNode.parentNode.style = "padding: 0";
|
||||
|
||||
|
||||
@@ -6,7 +6,6 @@ services:
|
||||
args:
|
||||
# specify which cuda version your card supports: https://developer.nvidia.com/cuda-gpus
|
||||
TORCH_CUDA_ARCH_LIST: ${TORCH_CUDA_ARCH_LIST}
|
||||
GPTQ_VERSION: ${GPTQ_VERSION}
|
||||
WEBUI_VERSION: ${WEBUI_VERSION}
|
||||
env_file: .env
|
||||
ports:
|
||||
|
||||
@@ -19,50 +19,6 @@ import requests
|
||||
import tqdm
|
||||
from tqdm.contrib.concurrent import thread_map
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('MODEL', type=str, default=None, nargs='?')
|
||||
parser.add_argument('--branch', type=str, default='main', help='Name of the Git branch to download from.')
|
||||
parser.add_argument('--threads', type=int, default=1, help='Number of files to download simultaneously.')
|
||||
parser.add_argument('--text-only', action='store_true', help='Only download text files (txt/json).')
|
||||
parser.add_argument('--output', type=str, default=None, help='The folder where the model should be saved.')
|
||||
parser.add_argument('--clean', action='store_true', help='Does not resume the previous download.')
|
||||
parser.add_argument('--check', action='store_true', help='Validates the checksums of model files.')
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
def get_file(url, output_folder):
|
||||
filename = Path(url.rsplit('/', 1)[1])
|
||||
output_path = output_folder / filename
|
||||
if output_path.exists() and not args.clean:
|
||||
# Check if the file has already been downloaded completely
|
||||
r = requests.get(url, stream=True)
|
||||
total_size = int(r.headers.get('content-length', 0))
|
||||
if output_path.stat().st_size >= total_size:
|
||||
return
|
||||
# Otherwise, resume the download from where it left off
|
||||
headers = {'Range': f'bytes={output_path.stat().st_size}-'}
|
||||
mode = 'ab'
|
||||
else:
|
||||
headers = {}
|
||||
mode = 'wb'
|
||||
|
||||
r = requests.get(url, stream=True, headers=headers)
|
||||
with open(output_path, mode) as f:
|
||||
total_size = int(r.headers.get('content-length', 0))
|
||||
block_size = 1024
|
||||
with tqdm.tqdm(total=total_size, unit='iB', unit_scale=True, bar_format='{l_bar}{bar}| {n_fmt:6}/{total_fmt:6} {rate_fmt:6}') as t:
|
||||
for data in r.iter_content(block_size):
|
||||
t.update(len(data))
|
||||
f.write(data)
|
||||
|
||||
|
||||
def sanitize_branch_name(branch_name):
|
||||
pattern = re.compile(r"^[a-zA-Z0-9._-]+$")
|
||||
if pattern.match(branch_name):
|
||||
return branch_name
|
||||
else:
|
||||
raise ValueError("Invalid branch name. Only alphanumeric characters, period, underscore and dash are allowed.")
|
||||
|
||||
|
||||
def select_model_from_default_options():
|
||||
models = {
|
||||
@@ -110,7 +66,20 @@ EleutherAI/pythia-1.4b-deduped
|
||||
return model, branch
|
||||
|
||||
|
||||
def get_download_links_from_huggingface(model, branch):
|
||||
def sanitize_model_and_branch_names(model, branch):
|
||||
if model[-1] == '/':
|
||||
model = model[:-1]
|
||||
if branch is None:
|
||||
branch = "main"
|
||||
else:
|
||||
pattern = re.compile(r"^[a-zA-Z0-9._-]+$")
|
||||
if not pattern.match(branch):
|
||||
raise ValueError("Invalid branch name. Only alphanumeric characters, period, underscore and dash are allowed.")
|
||||
|
||||
return model, branch
|
||||
|
||||
|
||||
def get_download_links_from_huggingface(model, branch, text_only=False):
|
||||
base = "https://huggingface.co"
|
||||
page = f"/api/models/{model}/tree/{branch}?cursor="
|
||||
cursor = b""
|
||||
@@ -142,14 +111,14 @@ def get_download_links_from_huggingface(model, branch):
|
||||
is_tokenizer = re.match("tokenizer.*\.model", fname)
|
||||
is_text = re.match(".*\.(txt|json|py|md)", fname) or is_tokenizer
|
||||
|
||||
if any((is_pytorch, is_safetensors, is_pt, is_tokenizer, is_text)):
|
||||
if any((is_pytorch, is_safetensors, is_pt, is_ggml, is_tokenizer, is_text)):
|
||||
if 'lfs' in dict[i]:
|
||||
sha256.append([fname, dict[i]['lfs']['oid']])
|
||||
if is_text:
|
||||
links.append(f"https://huggingface.co/{model}/resolve/{branch}/{fname}")
|
||||
classifications.append('text')
|
||||
continue
|
||||
if not args.text_only:
|
||||
if not text_only:
|
||||
links.append(f"https://huggingface.co/{model}/resolve/{branch}/{fname}")
|
||||
if is_safetensors:
|
||||
has_safetensors = True
|
||||
@@ -177,80 +146,125 @@ def get_download_links_from_huggingface(model, branch):
|
||||
return links, sha256, is_lora
|
||||
|
||||
|
||||
def download_files(file_list, output_folder, num_threads=8):
|
||||
thread_map(lambda url: get_file(url, output_folder), file_list, max_workers=num_threads, disable=True)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
model = args.MODEL
|
||||
branch = args.branch
|
||||
if model is None:
|
||||
model, branch = select_model_from_default_options()
|
||||
else:
|
||||
if model[-1] == '/':
|
||||
model = model[:-1]
|
||||
branch = args.branch
|
||||
if branch is None:
|
||||
branch = "main"
|
||||
else:
|
||||
try:
|
||||
branch = sanitize_branch_name(branch)
|
||||
except ValueError as err_branch:
|
||||
print(f"Error: {err_branch}")
|
||||
sys.exit()
|
||||
|
||||
links, sha256, is_lora = get_download_links_from_huggingface(model, branch)
|
||||
|
||||
if args.output is not None:
|
||||
base_folder = args.output
|
||||
else:
|
||||
def get_output_folder(model, branch, is_lora, base_folder=None):
|
||||
if base_folder is None:
|
||||
base_folder = 'models' if not is_lora else 'loras'
|
||||
|
||||
output_folder = f"{'_'.join(model.split('/')[-2:])}"
|
||||
if branch != 'main':
|
||||
output_folder += f'_{branch}'
|
||||
output_folder = Path(base_folder) / output_folder
|
||||
return output_folder
|
||||
|
||||
|
||||
def get_single_file(url, output_folder, start_from_scratch=False):
|
||||
filename = Path(url.rsplit('/', 1)[1])
|
||||
output_path = output_folder / filename
|
||||
if output_path.exists() and not start_from_scratch:
|
||||
# Check if the file has already been downloaded completely
|
||||
r = requests.get(url, stream=True)
|
||||
total_size = int(r.headers.get('content-length', 0))
|
||||
if output_path.stat().st_size >= total_size:
|
||||
return
|
||||
# Otherwise, resume the download from where it left off
|
||||
headers = {'Range': f'bytes={output_path.stat().st_size}-'}
|
||||
mode = 'ab'
|
||||
else:
|
||||
headers = {}
|
||||
mode = 'wb'
|
||||
|
||||
r = requests.get(url, stream=True, headers=headers)
|
||||
with open(output_path, mode) as f:
|
||||
total_size = int(r.headers.get('content-length', 0))
|
||||
block_size = 1024
|
||||
with tqdm.tqdm(total=total_size, unit='iB', unit_scale=True, bar_format='{l_bar}{bar}| {n_fmt:6}/{total_fmt:6} {rate_fmt:6}') as t:
|
||||
for data in r.iter_content(block_size):
|
||||
t.update(len(data))
|
||||
f.write(data)
|
||||
|
||||
|
||||
def start_download_threads(file_list, output_folder, start_from_scratch=False, threads=1):
|
||||
thread_map(lambda url: get_single_file(url, output_folder, start_from_scratch=start_from_scratch), file_list, max_workers=threads, disable=True)
|
||||
|
||||
|
||||
def download_model_files(model, branch, links, sha256, output_folder, start_from_scratch=False, threads=1):
|
||||
# Creating the folder and writing the metadata
|
||||
if not output_folder.exists():
|
||||
output_folder.mkdir()
|
||||
with open(output_folder / 'huggingface-metadata.txt', 'w') as f:
|
||||
f.write(f'url: https://huggingface.co/{model}\n')
|
||||
f.write(f'branch: {branch}\n')
|
||||
f.write(f'download date: {str(datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"))}\n')
|
||||
sha256_str = ''
|
||||
for i in range(len(sha256)):
|
||||
sha256_str += f' {sha256[i][1]} {sha256[i][0]}\n'
|
||||
if sha256_str != '':
|
||||
f.write(f'sha256sum:\n{sha256_str}')
|
||||
|
||||
# Downloading the files
|
||||
print(f"Downloading the model to {output_folder}")
|
||||
start_download_threads(links, output_folder, start_from_scratch=start_from_scratch, threads=threads)
|
||||
|
||||
|
||||
def check_model_files(model, branch, links, sha256, output_folder):
|
||||
# Validate the checksums
|
||||
validated = True
|
||||
for i in range(len(sha256)):
|
||||
fpath = (output_folder / sha256[i][0])
|
||||
|
||||
if not fpath.exists():
|
||||
print(f"The following file is missing: {fpath}")
|
||||
validated = False
|
||||
continue
|
||||
|
||||
with open(output_folder / sha256[i][0], "rb") as f:
|
||||
bytes = f.read()
|
||||
file_hash = hashlib.sha256(bytes).hexdigest()
|
||||
if file_hash != sha256[i][1]:
|
||||
print(f'Checksum failed: {sha256[i][0]} {sha256[i][1]}')
|
||||
validated = False
|
||||
else:
|
||||
print(f'Checksum validated: {sha256[i][0]} {sha256[i][1]}')
|
||||
|
||||
if validated:
|
||||
print('[+] Validated checksums of all model files!')
|
||||
else:
|
||||
print('[-] Invalid checksums. Rerun download-model.py with the --clean flag.')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('MODEL', type=str, default=None, nargs='?')
|
||||
parser.add_argument('--branch', type=str, default='main', help='Name of the Git branch to download from.')
|
||||
parser.add_argument('--threads', type=int, default=1, help='Number of files to download simultaneously.')
|
||||
parser.add_argument('--text-only', action='store_true', help='Only download text files (txt/json).')
|
||||
parser.add_argument('--output', type=str, default=None, help='The folder where the model should be saved.')
|
||||
parser.add_argument('--clean', action='store_true', help='Does not resume the previous download.')
|
||||
parser.add_argument('--check', action='store_true', help='Validates the checksums of model files.')
|
||||
args = parser.parse_args()
|
||||
|
||||
branch = args.branch
|
||||
model = args.MODEL
|
||||
if model is None:
|
||||
model, branch = select_model_from_default_options()
|
||||
|
||||
# Cleaning up the model/branch names
|
||||
try:
|
||||
model, branch = sanitize_model_and_branch_names(model, branch)
|
||||
except ValueError as err_branch:
|
||||
print(f"Error: {err_branch}")
|
||||
sys.exit()
|
||||
|
||||
# Getting the download links from Hugging Face
|
||||
links, sha256, is_lora = get_download_links_from_huggingface(model, branch, text_only=args.text_only)
|
||||
|
||||
# Getting the output folder
|
||||
output_folder = get_output_folder(model, branch, is_lora, base_folder=args.output)
|
||||
|
||||
if args.check:
|
||||
# Validate the checksums
|
||||
validated = True
|
||||
for i in range(len(sha256)):
|
||||
fpath = (output_folder / sha256[i][0])
|
||||
|
||||
if not fpath.exists():
|
||||
print(f"The following file is missing: {fpath}")
|
||||
validated = False
|
||||
continue
|
||||
|
||||
with open(output_folder / sha256[i][0], "rb") as f:
|
||||
bytes = f.read()
|
||||
file_hash = hashlib.sha256(bytes).hexdigest()
|
||||
if file_hash != sha256[i][1]:
|
||||
print(f'Checksum failed: {sha256[i][0]} {sha256[i][1]}')
|
||||
validated = False
|
||||
else:
|
||||
print(f'Checksum validated: {sha256[i][0]} {sha256[i][1]}')
|
||||
|
||||
if validated:
|
||||
print('[+] Validated checksums of all model files!')
|
||||
else:
|
||||
print('[-] Invalid checksums. Rerun download-model.py with the --clean flag.')
|
||||
|
||||
# Check previously downloaded files
|
||||
check_model_files(model, branch, links, sha256, output_folder)
|
||||
else:
|
||||
|
||||
# Creating the folder and writing the metadata
|
||||
if not output_folder.exists():
|
||||
output_folder.mkdir()
|
||||
with open(output_folder / 'huggingface-metadata.txt', 'w') as f:
|
||||
f.write(f'url: https://huggingface.co/{model}\n')
|
||||
f.write(f'branch: {branch}\n')
|
||||
f.write(f'download date: {str(datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"))}\n')
|
||||
sha256_str = ''
|
||||
for i in range(len(sha256)):
|
||||
sha256_str += f' {sha256[i][1]} {sha256[i][0]}\n'
|
||||
if sha256_str != '':
|
||||
f.write(f'sha256sum:\n{sha256_str}')
|
||||
|
||||
# Downloading the files
|
||||
print(f"Downloading the model to {output_folder}")
|
||||
download_files(links, output_folder, args.threads)
|
||||
# Download files
|
||||
download_model_files(model, branch, links, sha256, output_folder, threads=args.threads)
|
||||
|
||||
@@ -1,8 +1,23 @@
|
||||
import gradio as gr
|
||||
import os
|
||||
|
||||
# get the current directory of the script
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
# check if the bias_options.txt file exists, if not, create it
|
||||
bias_file = os.path.join(current_dir, "bias_options.txt")
|
||||
if not os.path.isfile(bias_file):
|
||||
with open(bias_file, "w") as f:
|
||||
f.write("*I am so happy*\n*I am so sad*\n*I am so excited*\n*I am so bored*\n*I am so angry*")
|
||||
|
||||
# read bias options from the text file
|
||||
with open(bias_file, "r") as f:
|
||||
bias_options = [line.strip() for line in f.readlines()]
|
||||
|
||||
params = {
|
||||
"activate": True,
|
||||
"bias string": " *I am so happy*",
|
||||
"use custom string": False,
|
||||
}
|
||||
|
||||
|
||||
@@ -11,7 +26,6 @@ def input_modifier(string):
|
||||
This function is applied to your text inputs before
|
||||
they are fed into the model.
|
||||
"""
|
||||
|
||||
return string
|
||||
|
||||
|
||||
@@ -19,7 +33,6 @@ def output_modifier(string):
|
||||
"""
|
||||
This function is applied to the model outputs.
|
||||
"""
|
||||
|
||||
return string
|
||||
|
||||
|
||||
@@ -29,9 +42,11 @@ def bot_prefix_modifier(string):
|
||||
the prefix text for the Bot and can be used to bias its
|
||||
behavior.
|
||||
"""
|
||||
|
||||
if params['activate']:
|
||||
return f'{string} {params["bias string"].strip()} '
|
||||
if params['use custom string']:
|
||||
return f'{string} {params["custom string"].strip()} '
|
||||
else:
|
||||
return f'{string} {params["bias string"].strip()} '
|
||||
else:
|
||||
return string
|
||||
|
||||
@@ -39,8 +54,29 @@ def bot_prefix_modifier(string):
|
||||
def ui():
|
||||
# Gradio elements
|
||||
activate = gr.Checkbox(value=params['activate'], label='Activate character bias')
|
||||
string = gr.Textbox(value=params["bias string"], label='Character bias')
|
||||
dropdown_string = gr.Dropdown(choices=bias_options, value=params["bias string"], label='Character bias', info='To edit the options in this dropdown edit the "bias_options.txt" file')
|
||||
use_custom_string = gr.Checkbox(value=False, label='Use custom bias textbox instead of dropdown')
|
||||
custom_string = gr.Textbox(value="", placeholder="Enter custom bias string", label="Custom Character Bias", info='To use this textbox activate the checkbox above')
|
||||
|
||||
# Event functions to update the parameters in the backend
|
||||
string.change(lambda x: params.update({"bias string": x}), string, None)
|
||||
def update_bias_string(x):
|
||||
if x:
|
||||
params.update({"bias string": x})
|
||||
else:
|
||||
params.update({"bias string": dropdown_string.get()})
|
||||
return x
|
||||
|
||||
def update_custom_string(x):
|
||||
params.update({"custom string": x})
|
||||
|
||||
dropdown_string.change(update_bias_string, dropdown_string, None)
|
||||
custom_string.change(update_custom_string, custom_string, None)
|
||||
activate.change(lambda x: params.update({"activate": x}), activate, None)
|
||||
use_custom_string.change(lambda x: params.update({"use custom string": x}), use_custom_string, None)
|
||||
|
||||
# Group elements together depending on the selected option
|
||||
def bias_string_group():
|
||||
if use_custom_string.value:
|
||||
return gr.Group([use_custom_string, custom_string])
|
||||
else:
|
||||
return dropdown_string
|
||||
|
||||
@@ -25,7 +25,7 @@ def caption_image(raw_image):
|
||||
|
||||
|
||||
def generate_chat_picture(picture, name1, name2):
|
||||
text = f'*{name1} sends {name2} a picture that contains the following: "{caption_image(picture)}"*'
|
||||
text = f'*{name1} sends {name2} a picture that contains the following: “{caption_image(picture)}”*'
|
||||
# lower the resolution of sent images for the chat, otherwise the log size gets out of control quickly with all the base64 values in visible history
|
||||
picture.thumbnail((300, 300))
|
||||
buffer = BytesIO()
|
||||
|
||||
@@ -100,10 +100,10 @@ def load_quantized(model_name):
|
||||
found_safetensors = list(path_to_model.glob("*.safetensors"))
|
||||
pt_path = None
|
||||
|
||||
if len(found_pts) == 1:
|
||||
pt_path = found_pts[0]
|
||||
elif len(found_safetensors) == 1:
|
||||
pt_path = found_safetensors[0]
|
||||
if len(found_pts) > 0:
|
||||
pt_path = found_pts[-1]
|
||||
elif len(found_safetensors) > 0:
|
||||
pt_path = found_safetensors[-1]
|
||||
else:
|
||||
if path_to_model.name.lower().startswith('llama-7b'):
|
||||
pt_model = f'llama-7b-{shared.args.wbits}bit'
|
||||
@@ -119,13 +119,14 @@ def load_quantized(model_name):
|
||||
# Try to find the .safetensors or .pt both in the model dir and in the subfolder
|
||||
for path in [Path(p + ext) for ext in ['.safetensors', '.pt'] for p in [f"{shared.args.model_dir}/{pt_model}", f"{path_to_model}/{pt_model}"]]:
|
||||
if path.exists():
|
||||
print(f"Found {path}")
|
||||
pt_path = path
|
||||
break
|
||||
|
||||
if not pt_path:
|
||||
print("Could not find the quantized model in .pt or .safetensors format, exiting...")
|
||||
exit()
|
||||
else:
|
||||
print(f"Found the following quantized model: {pt_path}")
|
||||
|
||||
# qwopqwop200's offload
|
||||
if model_type == 'llama' and shared.args.pre_layer:
|
||||
|
||||
@@ -22,6 +22,7 @@ def generate_chat_prompt(user_input, max_new_tokens, name1, name2, context, chat
|
||||
is_instruct = kwargs['is_instruct'] if 'is_instruct' in kwargs else False
|
||||
end_of_turn = kwargs['end_of_turn'] if 'end_of_turn' in kwargs else ''
|
||||
impersonate = kwargs['impersonate'] if 'impersonate' in kwargs else False
|
||||
_continue = kwargs['_continue'] if '_continue' in kwargs else False
|
||||
also_return_rows = kwargs['also_return_rows'] if 'also_return_rows' in kwargs else False
|
||||
rows = [f"{context.strip()}\n"]
|
||||
|
||||
@@ -39,7 +40,10 @@ def generate_chat_prompt(user_input, max_new_tokens, name1, name2, context, chat
|
||||
|
||||
i = len(shared.history['internal']) - 1
|
||||
while i >= 0 and len(encode(''.join(rows), max_new_tokens)[0]) < max_length:
|
||||
rows.insert(1, f"{prefix2}{shared.history['internal'][i][1].strip()}{end_of_turn}\n")
|
||||
if _continue and i == len(shared.history['internal']) - 1:
|
||||
rows.insert(1, f"{prefix2}{shared.history['internal'][i][1]}")
|
||||
else:
|
||||
rows.insert(1, f"{prefix2}{shared.history['internal'][i][1].strip()}{end_of_turn}\n")
|
||||
string = shared.history['internal'][i][0]
|
||||
if string not in ['', '<|BEGIN-VISIBLE-CHAT|>']:
|
||||
rows.insert(1, f"{prefix1}{string.strip()}{end_of_turn}\n")
|
||||
@@ -48,6 +52,8 @@ def generate_chat_prompt(user_input, max_new_tokens, name1, name2, context, chat
|
||||
if impersonate:
|
||||
rows.append(f"{prefix1.strip() if not is_instruct else prefix1}")
|
||||
limit = 2
|
||||
elif _continue:
|
||||
limit = 3
|
||||
else:
|
||||
# Adding the user message
|
||||
user_input = fix_newlines(user_input)
|
||||
@@ -99,7 +105,7 @@ def extract_message_from_reply(reply, name1, name2, stop_at_newline):
|
||||
return reply, next_character_found
|
||||
|
||||
|
||||
def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn, regenerate=False):
|
||||
def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn, regenerate=False, _continue=False):
|
||||
if mode == 'instruct':
|
||||
stopping_strings = [f"\n{name1}", f"\n{name2}"]
|
||||
else:
|
||||
@@ -107,6 +113,7 @@ def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_tu
|
||||
|
||||
# Defining some variables
|
||||
cumulative_reply = ''
|
||||
last_reply = [shared.history['internal'][-1][1], shared.history['visible'][-1][1]] if _continue else None
|
||||
just_started = True
|
||||
name1_original = name1
|
||||
visible_text = custom_generate_chat_prompt = None
|
||||
@@ -124,17 +131,22 @@ def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_tu
|
||||
|
||||
if visible_text is None:
|
||||
visible_text = text
|
||||
text = apply_extensions(text, "input")
|
||||
if not _continue:
|
||||
text = apply_extensions(text, "input")
|
||||
|
||||
# Generating the prompt
|
||||
kwargs = {'end_of_turn': end_of_turn, 'is_instruct': mode == 'instruct'}
|
||||
kwargs = {
|
||||
'end_of_turn': end_of_turn,
|
||||
'is_instruct': mode == 'instruct',
|
||||
'_continue': _continue
|
||||
}
|
||||
if custom_generate_chat_prompt is None:
|
||||
prompt = generate_chat_prompt(text, generate_state['max_new_tokens'], name1, name2, context, generate_state['chat_prompt_size'], **kwargs)
|
||||
else:
|
||||
prompt = custom_generate_chat_prompt(text, generate_state['max_new_tokens'], name1, name2, context, generate_state['chat_prompt_size'], **kwargs)
|
||||
|
||||
# Yield *Is typing...*
|
||||
if not regenerate:
|
||||
if not any((regenerate, _continue)):
|
||||
yield shared.history['visible'] + [[visible_text, shared.processing_message]]
|
||||
|
||||
# Generate
|
||||
@@ -154,11 +166,17 @@ def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_tu
|
||||
return shared.history['visible']
|
||||
if just_started:
|
||||
just_started = False
|
||||
shared.history['internal'].append(['', ''])
|
||||
shared.history['visible'].append(['', ''])
|
||||
if not _continue:
|
||||
shared.history['internal'].append(['', ''])
|
||||
shared.history['visible'].append(['', ''])
|
||||
|
||||
shared.history['internal'][-1] = [text, reply]
|
||||
shared.history['visible'][-1] = [visible_text, visible_reply]
|
||||
if _continue:
|
||||
sep = list(map(lambda x : ' ' if x[-1] != ' ' else '', last_reply))
|
||||
shared.history['internal'][-1] = [text, f'{last_reply[0]}{sep[0]}{reply}']
|
||||
shared.history['visible'][-1] = [visible_text, f'{last_reply[1]}{sep[1]}{visible_reply}']
|
||||
else:
|
||||
shared.history['internal'][-1] = [text, reply]
|
||||
shared.history['visible'][-1] = [visible_text, visible_reply]
|
||||
if not shared.args.no_stream:
|
||||
yield shared.history['visible']
|
||||
if next_character_found:
|
||||
@@ -220,6 +238,16 @@ def regenerate_wrapper(text, generate_state, name1, name2, context, mode, end_of
|
||||
yield chat_html_wrapper(shared.history['visible'], name1, name2, mode)
|
||||
|
||||
|
||||
def continue_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn):
|
||||
if (len(shared.history['visible']) == 1 and not shared.history['visible'][0][0]) or len(shared.history['internal']) == 0:
|
||||
yield chat_html_wrapper(shared.history['visible'], name1, name2, mode)
|
||||
else:
|
||||
# Yield ' ...'
|
||||
yield chat_html_wrapper(shared.history['visible'][:-1] + [[shared.history['visible'][-1][0], shared.history['visible'][-1][1] + ' ...']], name1, name2, mode)
|
||||
for history in chatbot_wrapper(shared.history['internal'][-1][0], generate_state, name1, name2, context, mode, end_of_turn, _continue=True):
|
||||
yield chat_html_wrapper(shared.history['visible'], name1, name2, mode)
|
||||
|
||||
|
||||
def remove_last_message(name1, name2, mode):
|
||||
if len(shared.history['visible']) > 0 and shared.history['internal'][-1][0] != '<|BEGIN-VISIBLE-CHAT|>':
|
||||
last = shared.history['visible'].pop()
|
||||
@@ -257,6 +285,9 @@ def clear_chat_log(name1, name2, greeting, mode):
|
||||
shared.history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', greeting]]
|
||||
shared.history['visible'] += [['', apply_extensions(greeting, "output")]]
|
||||
|
||||
# Save cleared logs
|
||||
save_history(mode)
|
||||
|
||||
return chat_html_wrapper(shared.history['visible'], name1, name2, mode)
|
||||
|
||||
|
||||
@@ -301,15 +332,23 @@ def tokenize_dialogue(dialogue, name1, name2, mode):
|
||||
return history
|
||||
|
||||
|
||||
def save_history(timestamp=True):
|
||||
if timestamp:
|
||||
fname = f"{shared.character}_{datetime.now().strftime('%Y%m%d-%H%M%S')}.json"
|
||||
def save_history(mode, timestamp=False):
|
||||
# Instruct mode histories should not be saved as if
|
||||
# Alpaca or Vicuna were characters
|
||||
if mode == 'instruct':
|
||||
if not timestamp:
|
||||
return
|
||||
fname = f"Instruct_{datetime.now().strftime('%Y%m%d-%H%M%S')}.json"
|
||||
else:
|
||||
fname = f"{shared.character}_persistent.json"
|
||||
if timestamp:
|
||||
fname = f"{shared.character}_{datetime.now().strftime('%Y%m%d-%H%M%S')}.json"
|
||||
else:
|
||||
fname = f"{shared.character}_persistent.json"
|
||||
if not Path('logs').exists():
|
||||
Path('logs').mkdir()
|
||||
with open(Path(f'logs/{fname}'), 'w', encoding='utf-8') as f:
|
||||
f.write(json.dumps({'data': shared.history['internal'], 'data_visible': shared.history['visible']}, indent=2))
|
||||
|
||||
return Path(f'logs/{fname}')
|
||||
|
||||
|
||||
@@ -323,16 +362,6 @@ def load_history(file, name1, name2):
|
||||
shared.history['visible'] = j['data_visible']
|
||||
else:
|
||||
shared.history['visible'] = copy.deepcopy(shared.history['internal'])
|
||||
# Compatibility with Pygmalion AI's official web UI
|
||||
elif 'chat' in j:
|
||||
shared.history['internal'] = [':'.join(x.split(':')[1:]).strip() for x in j['chat']]
|
||||
if len(j['chat']) > 0 and j['chat'][0].startswith(f'{name2}:'):
|
||||
shared.history['internal'] = [['<|BEGIN-VISIBLE-CHAT|>', shared.history['internal'][0]]] + [[shared.history['internal'][i], shared.history['internal'][i + 1]] for i in range(1, len(shared.history['internal']) - 1, 2)]
|
||||
shared.history['visible'] = copy.deepcopy(shared.history['internal'])
|
||||
shared.history['visible'][0][0] = ''
|
||||
else:
|
||||
shared.history['internal'] = [[shared.history['internal'][i], shared.history['internal'][i + 1]] for i in range(0, len(shared.history['internal']) - 1, 2)]
|
||||
shared.history['visible'] = copy.deepcopy(shared.history['internal'])
|
||||
except:
|
||||
shared.history['internal'] = tokenize_dialogue(file, name1, name2)
|
||||
shared.history['visible'] = copy.deepcopy(shared.history['internal'])
|
||||
@@ -368,8 +397,6 @@ def generate_pfp_cache(character):
|
||||
|
||||
def load_character(character, name1, name2, mode):
|
||||
shared.character = character
|
||||
shared.history['internal'] = []
|
||||
shared.history['visible'] = []
|
||||
context = greeting = end_of_turn = ""
|
||||
greeting_field = 'greeting'
|
||||
picture = None
|
||||
@@ -414,13 +441,22 @@ def load_character(character, name1, name2, mode):
|
||||
greeting = shared.settings['greeting']
|
||||
end_of_turn = shared.settings['end_of_turn']
|
||||
|
||||
if Path(f'logs/{shared.character}_persistent.json').exists():
|
||||
load_history(open(Path(f'logs/{shared.character}_persistent.json'), 'rb').read(), name1, name2)
|
||||
elif greeting != "":
|
||||
shared.history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', greeting]]
|
||||
shared.history['visible'] += [['', apply_extensions(greeting, "output")]]
|
||||
if mode != 'instruct':
|
||||
shared.history['internal'] = []
|
||||
shared.history['visible'] = []
|
||||
|
||||
return name1, name2, picture, greeting, context, end_of_turn, chat_html_wrapper(shared.history['visible'], name1, name2, mode, reset_cache=True)
|
||||
if Path(f'logs/{shared.character}_persistent.json').exists():
|
||||
load_history(open(Path(f'logs/{shared.character}_persistent.json'), 'rb').read(), name1, name2)
|
||||
else:
|
||||
# Insert greeting if it exists
|
||||
if greeting != "":
|
||||
shared.history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', greeting]]
|
||||
shared.history['visible'] += [['', apply_extensions(greeting, "output")]]
|
||||
|
||||
# Create .json log files since they don't already exist
|
||||
save_history(mode)
|
||||
|
||||
return name1, name2, picture, greeting, context, end_of_turn, chat_html_wrapper(shared.history['visible'], name1, name2, mode)
|
||||
|
||||
|
||||
def load_default_history(name1, name2):
|
||||
|
||||
@@ -164,10 +164,9 @@ def generate_instruct_html(history):
|
||||
def generate_cai_chat_html(history, name1, name2, reset_cache=False):
|
||||
output = f'<style>{cai_css}</style><div class="chat" id="chat">'
|
||||
|
||||
# The time.time() is to prevent the brower from caching the image
|
||||
suffix = f"?{time.time()}" if reset_cache else f"?{name2}"
|
||||
img_bot = f'<img src="file/cache/pfp_character.png{suffix}">' if Path("cache/pfp_character.png").exists() else ''
|
||||
img_me = f'<img src="file/cache/pfp_me.png{suffix}">' if Path("cache/pfp_me.png").exists() else ''
|
||||
# We use ?name2 and ?time.time() to force the browser to reset caches
|
||||
img_bot = f'<img src="file/cache/pfp_character.png?{name2}">' if Path("cache/pfp_character.png").exists() else ''
|
||||
img_me = f'<img src="file/cache/pfp_me.png?{time.time() if reset_cache else ""}">' if Path("cache/pfp_me.png").exists() else ''
|
||||
|
||||
for i, _row in enumerate(history[::-1]):
|
||||
row = [convert_to_markdown(entry) for entry in _row]
|
||||
|
||||
176
modules/llama_attn_hijack.py
Normal file
176
modules/llama_attn_hijack.py
Normal file
@@ -0,0 +1,176 @@
|
||||
import math
|
||||
import sys
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import transformers.models.llama.modeling_llama
|
||||
|
||||
from typing import Optional
|
||||
from typing import Tuple
|
||||
|
||||
import modules.shared as shared
|
||||
|
||||
|
||||
if shared.args.xformers:
|
||||
try:
|
||||
import xformers.ops
|
||||
except Exception:
|
||||
print("🔴 xformers not found! Please install it before trying to use it.", file=sys.stderr)
|
||||
|
||||
|
||||
def hijack_llama_attention():
|
||||
if shared.args.xformers:
|
||||
transformers.models.llama.modeling_llama.LlamaAttention.forward = xformers_forward
|
||||
print("Replaced attention with xformers_attention")
|
||||
elif shared.args.sdp_attention:
|
||||
transformers.models.llama.modeling_llama.LlamaAttention.forward = sdp_attention_forward
|
||||
print("Replaced attention with sdp_attention")
|
||||
|
||||
|
||||
def xformers_forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
kv_seq_len = key_states.shape[-2]
|
||||
if past_key_value is not None:
|
||||
kv_seq_len += past_key_value[0].shape[-2]
|
||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||
query_states, key_states = transformers.models.llama.modeling_llama.apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
||||
# [bsz, nh, t, hd]
|
||||
|
||||
if past_key_value is not None:
|
||||
# reuse k, v, self_attention
|
||||
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
||||
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
||||
|
||||
past_key_value = (key_states, value_states) if use_cache else None
|
||||
|
||||
#We only apply xformers optimizations if we don't need to output the whole attention matrix
|
||||
if not output_attentions:
|
||||
dtype = query_states.dtype
|
||||
|
||||
query_states = query_states.transpose(1, 2)
|
||||
key_states = key_states.transpose(1, 2)
|
||||
value_states = value_states.transpose(1, 2)
|
||||
|
||||
#This is a nasty hack. We know attention_mask in transformers is either LowerTriangular or all Zeros.
|
||||
#We therefore check if one element in the upper triangular portion is zero. If it is, then the mask is all zeros.
|
||||
if attention_mask is None or attention_mask[0, 0, 0, 1] == 0:
|
||||
# input and output should be of form (bsz, q_len, num_heads, head_dim)
|
||||
attn_output = xformers.ops.memory_efficient_attention(query_states, key_states, value_states, attn_bias=None)
|
||||
else:
|
||||
# input and output should be of form (bsz, q_len, num_heads, head_dim)
|
||||
attn_output = xformers.ops.memory_efficient_attention(query_states, key_states, value_states, attn_bias=xformers.ops.LowerTriangularMask())
|
||||
attn_weights = None
|
||||
else:
|
||||
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
|
||||
|
||||
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
|
||||
raise ValueError(
|
||||
f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is"
|
||||
f" {attn_weights.size()}"
|
||||
)
|
||||
|
||||
if attention_mask is not None:
|
||||
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
|
||||
raise ValueError(
|
||||
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
|
||||
)
|
||||
attn_weights = attn_weights + attention_mask
|
||||
attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
|
||||
|
||||
# upcast attention to fp32
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
|
||||
attn_output = torch.matmul(attn_weights, value_states)
|
||||
|
||||
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
|
||||
raise ValueError(
|
||||
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
|
||||
f" {attn_output.size()}"
|
||||
)
|
||||
|
||||
attn_output = attn_output.transpose(1, 2)
|
||||
|
||||
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
||||
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
return attn_output, attn_weights, past_key_value
|
||||
|
||||
|
||||
def sdp_attention_forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
kv_seq_len = key_states.shape[-2]
|
||||
if past_key_value is not None:
|
||||
kv_seq_len += past_key_value[0].shape[-2]
|
||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||
query_states, key_states = transformers.models.llama.modeling_llama.apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
||||
# [bsz, nh, t, hd]
|
||||
|
||||
if past_key_value is not None:
|
||||
# reuse k, v, self_attention
|
||||
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
||||
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
||||
|
||||
past_key_value = (key_states, value_states) if use_cache else None
|
||||
|
||||
#We only apply sdp attention if we don't need to output the whole attention matrix
|
||||
if not output_attentions:
|
||||
attn_output = torch.nn.functional.scaled_dot_product_attention(query_states, key_states, value_states, attn_mask=attention_mask, is_causal=False)
|
||||
attn_weights = None
|
||||
else:
|
||||
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
|
||||
|
||||
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
|
||||
raise ValueError(
|
||||
f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is"
|
||||
f" {attn_weights.size()}"
|
||||
)
|
||||
|
||||
if attention_mask is not None:
|
||||
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
|
||||
raise ValueError(
|
||||
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
|
||||
)
|
||||
attn_weights = attn_weights + attention_mask
|
||||
attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
|
||||
|
||||
# upcast attention to fp32
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
|
||||
attn_output = torch.matmul(attn_weights, value_states)
|
||||
|
||||
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
|
||||
raise ValueError(
|
||||
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
|
||||
f" {attn_output.size()}"
|
||||
)
|
||||
|
||||
attn_output = attn_output.transpose(1, 2)
|
||||
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
||||
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
return attn_output, attn_weights, past_key_value
|
||||
@@ -14,6 +14,7 @@ from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer,
|
||||
BitsAndBytesConfig, LlamaTokenizer)
|
||||
|
||||
import modules.shared as shared
|
||||
from modules import llama_attn_hijack
|
||||
|
||||
transformers.logging.set_verbosity_error()
|
||||
|
||||
@@ -169,11 +170,23 @@ def load_model(model_name):
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(checkpoint, **params)
|
||||
|
||||
# Hijack attention with xformers
|
||||
if any((shared.args.xformers, shared.args.sdp_attention)):
|
||||
llama_attn_hijack.hijack_llama_attention()
|
||||
|
||||
# Loading the tokenizer
|
||||
if any((k in shared.model_name.lower() for k in ['gpt4chan', 'gpt-4chan'])) and Path(f"{shared.args.model_dir}/gpt-j-6B/").exists():
|
||||
tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/gpt-j-6B/"))
|
||||
elif type(model) is transformers.LlamaForCausalLM:
|
||||
tokenizer = LlamaTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/{shared.model_name}/"), clean_up_tokenization_spaces=True)
|
||||
# Leaving this here until the LLaMA tokenizer gets figured out.
|
||||
# For some people this fixes things, for others it causes an error.
|
||||
try:
|
||||
tokenizer.eos_token_id = 2
|
||||
tokenizer.bos_token_id = 1
|
||||
tokenizer.pad_token_id = 0
|
||||
except:
|
||||
pass
|
||||
else:
|
||||
tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/{shared.model_name}/"))
|
||||
tokenizer.truncation_side = 'left'
|
||||
|
||||
@@ -35,6 +35,7 @@ settings = {
|
||||
'greeting': 'Hello there!',
|
||||
'end_of_turn': '',
|
||||
'stop_at_newline': False,
|
||||
'add_bos_token': True,
|
||||
'chat_prompt_size': 2048,
|
||||
'chat_prompt_size_min': 0,
|
||||
'chat_prompt_size_max': 2048,
|
||||
@@ -44,7 +45,7 @@ settings = {
|
||||
'default_extensions': [],
|
||||
'chat_default_extensions': ["gallery"],
|
||||
'presets': {
|
||||
'default': 'NovelAI-Sphinx Moth',
|
||||
'default': 'Default',
|
||||
'.*(alpaca|llama)': "LLaMA-Precise",
|
||||
'.*pygmalion': 'NovelAI-Storywriter',
|
||||
'.*RWKV': 'Naive',
|
||||
@@ -89,7 +90,7 @@ parser.add_argument('--extensions', type=str, nargs="+", help='The list of exten
|
||||
parser.add_argument('--verbose', action='store_true', help='Print the prompts to the terminal.')
|
||||
|
||||
# Accelerate/transformers
|
||||
parser.add_argument('--cpu', action='store_true', help='Use the CPU to generate text.')
|
||||
parser.add_argument('--cpu', action='store_true', help='Use the CPU to generate text. Warning: Training on CPU is extremely slow.')
|
||||
parser.add_argument('--auto-devices', action='store_true', help='Automatically split the model across the available GPU(s) and CPU.')
|
||||
parser.add_argument('--gpu-memory', type=str, nargs="+", help='Maxmimum GPU memory in GiB to be allocated per GPU. Example: --gpu-memory 10 for a single GPU, --gpu-memory 10 5 for two GPUs. You can also set values in MiB like --gpu-memory 3500MiB.')
|
||||
parser.add_argument('--cpu-memory', type=str, help='Maximum CPU memory in GiB to allocate for offloaded weights. Same as above.')
|
||||
@@ -98,6 +99,8 @@ parser.add_argument('--disk-cache-dir', type=str, default="cache", help='Directo
|
||||
parser.add_argument('--load-in-8bit', action='store_true', help='Load the model with 8-bit precision.')
|
||||
parser.add_argument('--bf16', action='store_true', help='Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU.')
|
||||
parser.add_argument('--no-cache', action='store_true', help='Set use_cache to False while generating text. This reduces the VRAM usage a bit at a performance cost.')
|
||||
parser.add_argument('--xformers', action='store_true', help="Use xformer's memory efficient attention. This should increase your tokens/s.")
|
||||
parser.add_argument('--sdp-attention', action='store_true', help="Use torch 2.0's sdp attention.")
|
||||
|
||||
# llama.cpp
|
||||
parser.add_argument('--threads', type=int, default=0, help='Number of threads to use in llama.cpp.')
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import random
|
||||
import re
|
||||
import time
|
||||
import traceback
|
||||
@@ -21,7 +22,7 @@ def get_max_prompt_length(tokens):
|
||||
return max_length
|
||||
|
||||
|
||||
def encode(prompt, tokens_to_generate=0, add_special_tokens=True):
|
||||
def encode(prompt, tokens_to_generate=0, add_special_tokens=True, add_bos_token=True):
|
||||
if any((shared.is_RWKV, shared.is_llamacpp)):
|
||||
input_ids = shared.tokenizer.encode(str(prompt))
|
||||
input_ids = np.array(input_ids).reshape(1, len(input_ids))
|
||||
@@ -29,6 +30,12 @@ def encode(prompt, tokens_to_generate=0, add_special_tokens=True):
|
||||
else:
|
||||
input_ids = shared.tokenizer.encode(str(prompt), return_tensors='pt', truncation=True, max_length=get_max_prompt_length(tokens_to_generate), add_special_tokens=add_special_tokens)
|
||||
|
||||
# This is a hack for making replies more creative.
|
||||
if not add_bos_token and input_ids[0][0] == shared.tokenizer.bos_token_id:
|
||||
input_ids = input_ids[:, 1:]
|
||||
|
||||
# Llama adds this extra token when the first character is '\n', and this
|
||||
# compromises the stopping criteria, so we just remove it
|
||||
if type(shared.tokenizer) is transformers.LlamaTokenizer and input_ids[0][0] == 29871:
|
||||
input_ids = input_ids[:, 1:]
|
||||
|
||||
@@ -63,8 +70,6 @@ def generate_softprompt_input_tensors(input_ids):
|
||||
return inputs_embeds, filler_input_ids
|
||||
|
||||
# Removes empty replies from gpt4chan outputs
|
||||
|
||||
|
||||
def fix_gpt4chan(s):
|
||||
for i in range(10):
|
||||
s = re.sub("--- [0-9]*\n>>[0-9]*\n---", "---", s)
|
||||
@@ -73,8 +78,6 @@ def fix_gpt4chan(s):
|
||||
return s
|
||||
|
||||
# Fix the LaTeX equations in galactica
|
||||
|
||||
|
||||
def fix_galactica(s):
|
||||
s = s.replace(r'\[', r'$')
|
||||
s = s.replace(r'\]', r'$')
|
||||
@@ -101,10 +104,13 @@ def formatted_outputs(reply, model_name):
|
||||
|
||||
|
||||
def set_manual_seed(seed):
|
||||
if seed != -1:
|
||||
torch.manual_seed(seed)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
seed = int(seed)
|
||||
if seed == -1:
|
||||
seed = random.randint(1, 2**31)
|
||||
torch.manual_seed(seed)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
return seed
|
||||
|
||||
|
||||
def stop_everything_event():
|
||||
@@ -113,7 +119,7 @@ def stop_everything_event():
|
||||
|
||||
def generate_reply(question, generate_state, eos_token=None, stopping_strings=[]):
|
||||
clear_torch_cache()
|
||||
set_manual_seed(generate_state['seed'])
|
||||
seed = set_manual_seed(generate_state['seed'])
|
||||
shared.stop_everything = False
|
||||
generate_params = {}
|
||||
t0 = time.time()
|
||||
@@ -155,10 +161,10 @@ def generate_reply(question, generate_state, eos_token=None, stopping_strings=[]
|
||||
t1 = time.time()
|
||||
original_tokens = len(encode(original_question)[0])
|
||||
new_tokens = len(encode(output)[0]) - original_tokens
|
||||
print(f'Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens})')
|
||||
print(f'Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens}, seed {seed})')
|
||||
return
|
||||
|
||||
input_ids = encode(question, generate_state['max_new_tokens'])
|
||||
input_ids = encode(question, generate_state['max_new_tokens'], add_bos_token=generate_state['add_bos_token'])
|
||||
original_input_ids = input_ids
|
||||
output = input_ids[0]
|
||||
|
||||
@@ -176,8 +182,6 @@ def generate_reply(question, generate_state, eos_token=None, stopping_strings=[]
|
||||
generate_params[k] = generate_state[k]
|
||||
generate_params['eos_token_id'] = eos_token_ids
|
||||
generate_params['stopping_criteria'] = stopping_criteria_list
|
||||
if shared.args.no_stream:
|
||||
generate_params['min_length'] = 0
|
||||
else:
|
||||
for k in ['max_new_tokens', 'do_sample', 'temperature']:
|
||||
generate_params[k] = generate_state[k]
|
||||
@@ -276,5 +280,5 @@ def generate_reply(question, generate_state, eos_token=None, stopping_strings=[]
|
||||
t1 = time.time()
|
||||
original_tokens = len(original_input_ids[0])
|
||||
new_tokens = len(output) - original_tokens
|
||||
print(f'Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens})')
|
||||
print(f'Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens}, seed {seed})')
|
||||
return
|
||||
|
||||
@@ -238,7 +238,7 @@ def do_train(lora_name: str, micro_batch_size: int, batch_size: int, epochs: int
|
||||
warmup_steps=100,
|
||||
num_train_epochs=epochs,
|
||||
learning_rate=actual_lr,
|
||||
fp16=True,
|
||||
fp16=False if shared.args.cpu else True,
|
||||
logging_steps=20,
|
||||
evaluation_strategy="steps" if eval_data is not None else "no",
|
||||
save_strategy="steps",
|
||||
@@ -248,7 +248,8 @@ def do_train(lora_name: str, micro_batch_size: int, batch_size: int, epochs: int
|
||||
save_total_limit=3,
|
||||
load_best_model_at_end=True if eval_data is not None else False,
|
||||
# TODO: Enable multi-device support
|
||||
ddp_find_unused_parameters=None
|
||||
ddp_find_unused_parameters=None,
|
||||
no_cuda=shared.args.cpu
|
||||
),
|
||||
data_collator=transformers.DataCollatorForLanguageModeling(shared.tokenizer, mlm=False),
|
||||
callbacks=list([Callbacks()])
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
accelerate==0.18.0
|
||||
bitsandbytes==0.37.2
|
||||
datasets
|
||||
flexgen==0.1.7
|
||||
gradio==3.24.1
|
||||
markdown
|
||||
numpy
|
||||
Pillow>=9.5.0
|
||||
peft==0.2.0
|
||||
requests
|
||||
rwkv==0.7.3
|
||||
@@ -13,3 +13,6 @@ sentencepiece
|
||||
pyyaml
|
||||
tqdm
|
||||
git+https://github.com/huggingface/transformers
|
||||
bitsandbytes==0.37.2; platform_system != "Windows"
|
||||
llama-cpp-python==0.1.32; platform_system != "Windows"
|
||||
https://github.com/abetlen/llama-cpp-python/releases/download/v0.1.30/llama_cpp_python-0.1.30-cp310-cp310-win_amd64.whl; platform_system == "Windows"
|
||||
|
||||
178
server.py
178
server.py
@@ -2,11 +2,14 @@ import os
|
||||
|
||||
os.environ['GRADIO_ANALYTICS_ENABLED'] = 'False'
|
||||
|
||||
import importlib
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import time
|
||||
import traceback
|
||||
import zipfile
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
@@ -21,6 +24,7 @@ from modules.LoRA import add_lora_to_model
|
||||
from modules.models import load_model, load_soft_prompt, unload_model
|
||||
from modules.text_generation import generate_reply, stop_everything_event
|
||||
|
||||
|
||||
# Loading custom settings
|
||||
settings_file = None
|
||||
if shared.args.settings is not None and Path(shared.args.settings).exists():
|
||||
@@ -172,6 +176,34 @@ def create_prompt_menus():
|
||||
shared.gradio['save_prompt'].click(save_prompt, [shared.gradio['textbox']], [shared.gradio['status']], show_progress=False)
|
||||
|
||||
|
||||
def download_model_wrapper(repo_id):
|
||||
try:
|
||||
downloader = importlib.import_module("download-model")
|
||||
|
||||
model = repo_id
|
||||
branch = "main"
|
||||
check = False
|
||||
|
||||
yield("Cleaning up the model/branch names")
|
||||
model, branch = downloader.sanitize_model_and_branch_names(model, branch)
|
||||
|
||||
yield("Getting the download links from Hugging Face")
|
||||
links, sha256, is_lora = downloader.get_download_links_from_huggingface(model, branch, text_only=False)
|
||||
|
||||
yield("Getting the output folder")
|
||||
output_folder = downloader.get_output_folder(model, branch, is_lora)
|
||||
|
||||
if check:
|
||||
yield("Checking previously downloaded files")
|
||||
downloader.check_model_files(model, branch, links, sha256, output_folder)
|
||||
else:
|
||||
yield(f"Downloading files to {output_folder}")
|
||||
downloader.download_model_files(model, branch, links, sha256, output_folder, threads=1)
|
||||
yield("Done!")
|
||||
except:
|
||||
yield traceback.format_exc()
|
||||
|
||||
|
||||
def create_model_menus():
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
@@ -182,14 +214,26 @@ def create_model_menus():
|
||||
with gr.Row():
|
||||
shared.gradio['lora_menu'] = gr.Dropdown(choices=available_loras, value=shared.lora_name, label='LoRA')
|
||||
ui.create_refresh_button(shared.gradio['lora_menu'], lambda: None, lambda: {'choices': get_available_loras()}, 'refresh-button')
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
shared.gradio['custom_model_menu'] = gr.Textbox(label="Download custom model or LoRA",
|
||||
info="Enter Hugging Face username/model path, e.g: facebook/galactica-125m")
|
||||
with gr.Column():
|
||||
shared.gradio['download_button'] = gr.Button("Download")
|
||||
shared.gradio['download_status'] = gr.Markdown()
|
||||
with gr.Column():
|
||||
pass
|
||||
|
||||
shared.gradio['model_menu'].change(load_model_wrapper, shared.gradio['model_menu'], shared.gradio['model_menu'], show_progress=True)
|
||||
shared.gradio['lora_menu'].change(load_lora_wrapper, shared.gradio['lora_menu'], shared.gradio['lora_menu'], show_progress=True)
|
||||
shared.gradio['download_button'].click(download_model_wrapper, shared.gradio['custom_model_menu'], shared.gradio['download_status'], show_progress=False)
|
||||
|
||||
|
||||
def create_settings_menus(default_preset):
|
||||
generate_params = load_preset_values(default_preset if not shared.args.flexgen else 'Naive', {}, return_dict=True)
|
||||
for k in ['max_new_tokens', 'seed', 'stop_at_newline', 'chat_prompt_size', 'chat_generation_attempts']:
|
||||
for k in ['max_new_tokens', 'seed', 'stop_at_newline', 'chat_prompt_size', 'chat_generation_attempts', 'add_bos_token']:
|
||||
generate_params[k] = shared.settings[k]
|
||||
shared.gradio['generate_state'] = gr.State(generate_params)
|
||||
|
||||
@@ -204,18 +248,18 @@ def create_settings_menus(default_preset):
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
with gr.Box():
|
||||
gr.Markdown('Custom generation parameters ([reference](https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.GenerationConfig))')
|
||||
gr.Markdown('Custom generation parameters ([click here to view technical documentation](https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.GenerationConfig))')
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
shared.gradio['temperature'] = gr.Slider(0.01, 1.99, value=generate_params['temperature'], step=0.01, label='temperature')
|
||||
shared.gradio['top_p'] = gr.Slider(0.0, 1.0, value=generate_params['top_p'], step=0.01, label='top_p')
|
||||
shared.gradio['top_k'] = gr.Slider(0, 200, value=generate_params['top_k'], step=1, label='top_k')
|
||||
shared.gradio['typical_p'] = gr.Slider(0.0, 1.0, value=generate_params['typical_p'], step=0.01, label='typical_p')
|
||||
shared.gradio['temperature'] = gr.Slider(0.01, 1.99, value=generate_params['temperature'], step=0.01, label='temperature', info='Primary factor to control randomness of outputs. 0 = deterministic (only the most likely token is used). Higher value = more randomness.')
|
||||
shared.gradio['top_p'] = gr.Slider(0.0, 1.0, value=generate_params['top_p'], step=0.01, label='top_p', info='If not set to 1, select tokens with probabilities adding up to less than this number. Higher value = higher range of possible random results.')
|
||||
shared.gradio['top_k'] = gr.Slider(0, 200, value=generate_params['top_k'], step=1, label='top_k', info='Similar to top_p, but select instead only the top_k most likely tokens. Higher value = higher range of possible random results.')
|
||||
shared.gradio['typical_p'] = gr.Slider(0.0, 1.0, value=generate_params['typical_p'], step=0.01, label='typical_p', info='If not set to 1, select only tokens that are at least this much more likely to appear than random tokens, given the prior text.')
|
||||
with gr.Column():
|
||||
shared.gradio['repetition_penalty'] = gr.Slider(1.0, 1.5, value=generate_params['repetition_penalty'], step=0.01, label='repetition_penalty')
|
||||
shared.gradio['encoder_repetition_penalty'] = gr.Slider(0.8, 1.5, value=generate_params['encoder_repetition_penalty'], step=0.01, label='encoder_repetition_penalty')
|
||||
shared.gradio['no_repeat_ngram_size'] = gr.Slider(0, 20, step=1, value=generate_params['no_repeat_ngram_size'], label='no_repeat_ngram_size')
|
||||
shared.gradio['min_length'] = gr.Slider(0, 2000, step=1, value=generate_params['min_length'] if shared.args.no_stream else 0, label='min_length', interactive=shared.args.no_stream)
|
||||
shared.gradio['repetition_penalty'] = gr.Slider(1.0, 1.5, value=generate_params['repetition_penalty'], step=0.01, label='repetition_penalty', info='Exponential penalty factor for repeating prior tokens. 1 means no penalty, higher value = less repetition, lower value = more repetition.')
|
||||
shared.gradio['encoder_repetition_penalty'] = gr.Slider(0.8, 1.5, value=generate_params['encoder_repetition_penalty'], step=0.01, label='encoder_repetition_penalty', info='Also known as the "Hallucinations filter". Used to penalize tokens that are *not* in the prior text. Higher value = more likely to stay in context, lower value = more likely to diverge.')
|
||||
shared.gradio['no_repeat_ngram_size'] = gr.Slider(0, 20, step=1, value=generate_params['no_repeat_ngram_size'], label='no_repeat_ngram_size', info='If not set to 0, specifies the length of token sets that are completely blocked from repeating at all. Higher values = blocks larger phrases, lower values = blocks words or letters from repeating. Only 0 or high values are a good idea in most cases.')
|
||||
shared.gradio['min_length'] = gr.Slider(0, 2000, step=1, value=generate_params['min_length'], label='min_length', info='Minimum generation length in tokens.')
|
||||
shared.gradio['do_sample'] = gr.Checkbox(value=generate_params['do_sample'], label='do_sample')
|
||||
with gr.Column():
|
||||
with gr.Box():
|
||||
@@ -229,6 +273,7 @@ def create_settings_menus(default_preset):
|
||||
with gr.Column():
|
||||
shared.gradio['length_penalty'] = gr.Slider(-5, 5, value=generate_params['length_penalty'], label='length_penalty')
|
||||
shared.gradio['early_stopping'] = gr.Checkbox(value=generate_params['early_stopping'], label='early_stopping')
|
||||
shared.gradio['add_bos_token'] = gr.Checkbox(value=shared.settings['add_bos_token'], label='Add the bos_token to the beginning of prompts', info='Disabling this can make the replies more creative.')
|
||||
|
||||
with gr.Accordion('Soft prompt', open=False):
|
||||
with gr.Row():
|
||||
@@ -327,8 +372,9 @@ def create_interface():
|
||||
shared.gradio['Generate'] = gr.Button('Generate', elem_id='Generate')
|
||||
shared.gradio['Stop'] = gr.Button('Stop', elem_id="stop")
|
||||
with gr.Row():
|
||||
shared.gradio['Impersonate'] = gr.Button('Impersonate')
|
||||
shared.gradio['Regenerate'] = gr.Button('Regenerate')
|
||||
shared.gradio['Continue'] = gr.Button('Continue')
|
||||
shared.gradio['Impersonate'] = gr.Button('Impersonate')
|
||||
with gr.Row():
|
||||
shared.gradio['Copy last reply'] = gr.Button('Copy last reply')
|
||||
shared.gradio['Replace last reply'] = gr.Button('Replace last reply')
|
||||
@@ -339,7 +385,7 @@ def create_interface():
|
||||
shared.gradio['Clear history-cancel'] = gr.Button('Cancel', visible=False)
|
||||
|
||||
shared.gradio["Chat mode"] = gr.Radio(choices=["cai-chat", "chat", "instruct"], value="cai-chat", label="Mode")
|
||||
shared.gradio["Instruction templates"] = gr.Dropdown(choices=get_available_instruction_templates(), label="Instruction template", value="None", visible=False)
|
||||
shared.gradio["Instruction templates"] = gr.Dropdown(choices=get_available_instruction_templates(), label="Instruction template", value="None", visible=False, info="Change this according to the model/LoRA that you are using.")
|
||||
|
||||
with gr.Tab("Character", elem_id="chat-settings"):
|
||||
with gr.Row():
|
||||
@@ -394,56 +440,72 @@ def create_interface():
|
||||
create_settings_menus(default_preset)
|
||||
|
||||
shared.input_params = [shared.gradio[k] for k in ['Chat input', 'generate_state', 'name1', 'name2', 'context', 'Chat mode', 'end_of_turn']]
|
||||
|
||||
def set_chat_input(textbox):
|
||||
return textbox, ""
|
||||
|
||||
gen_events.append(shared.gradio['Generate'].click(set_chat_input, shared.gradio['textbox'], [shared.gradio['Chat input'], shared.gradio['textbox']], show_progress=False))
|
||||
gen_events.append(shared.gradio['Generate'].click(chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream))
|
||||
gen_events.append(shared.gradio['textbox'].submit(set_chat_input, shared.gradio['textbox'], [shared.gradio['Chat input'], shared.gradio['textbox']], show_progress=False))
|
||||
gen_events.append(shared.gradio['textbox'].submit(chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream))
|
||||
gen_events.append(shared.gradio['Regenerate'].click(chat.regenerate_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream))
|
||||
gen_events.append(shared.gradio['Impersonate'].click(chat.impersonate_wrapper, shared.input_params, shared.gradio['textbox'], show_progress=shared.args.no_stream))
|
||||
shared.gradio['Stop'].click(stop_everything_event, [], [], queue=False, cancels=gen_events if shared.args.no_stream else None)
|
||||
|
||||
shared.gradio['Copy last reply'].click(chat.send_last_reply_to_input, [], shared.gradio['textbox'], show_progress=shared.args.no_stream)
|
||||
shared.gradio['Replace last reply'].click(chat.replace_last_reply, [shared.gradio[k] for k in ['textbox', 'name1', 'name2', 'Chat mode']], shared.gradio['display'], show_progress=shared.args.no_stream)
|
||||
|
||||
# Clear history with confirmation
|
||||
clear_arr = [shared.gradio[k] for k in ['Clear history-confirm', 'Clear history', 'Clear history-cancel']]
|
||||
reload_inputs = [shared.gradio[k] for k in ['name1', 'name2', 'Chat mode']]
|
||||
|
||||
gen_events.append(shared.gradio['Generate'].click(
|
||||
lambda x: (x, ''), shared.gradio['textbox'], [shared.gradio['Chat input'], shared.gradio['textbox']], show_progress=False).then(
|
||||
chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream).then(
|
||||
chat.save_history, shared.gradio['Chat mode'], None, show_progress=False)
|
||||
)
|
||||
|
||||
gen_events.append(shared.gradio['textbox'].submit(
|
||||
lambda x: (x, ''), shared.gradio['textbox'], [shared.gradio['Chat input'], shared.gradio['textbox']], show_progress=False).then(
|
||||
chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream).then(
|
||||
chat.save_history, shared.gradio['Chat mode'], None, show_progress=False)
|
||||
)
|
||||
|
||||
gen_events.append(shared.gradio['Regenerate'].click(
|
||||
chat.regenerate_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream).then(
|
||||
chat.save_history, shared.gradio['Chat mode'], None, show_progress=False)
|
||||
)
|
||||
|
||||
gen_events.append(shared.gradio['Continue'].click(
|
||||
chat.continue_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream).then(
|
||||
chat.save_history, shared.gradio['Chat mode'], None, show_progress=False)
|
||||
)
|
||||
|
||||
shared.gradio['Replace last reply'].click(
|
||||
chat.replace_last_reply, [shared.gradio[k] for k in ['textbox', 'name1', 'name2', 'Chat mode']], shared.gradio['display'], show_progress=shared.args.no_stream).then(
|
||||
lambda x: '', shared.gradio['textbox'], shared.gradio['textbox'], show_progress=False).then(
|
||||
chat.save_history, shared.gradio['Chat mode'], None, show_progress=False)
|
||||
|
||||
shared.gradio['Clear history-confirm'].click(
|
||||
lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, clear_arr).then(
|
||||
chat.clear_chat_log, [shared.gradio[k] for k in ['name1', 'name2', 'greeting', 'Chat mode']], shared.gradio['display']).then(
|
||||
chat.save_history, shared.gradio['Chat mode'], None, show_progress=False)
|
||||
|
||||
shared.gradio['Stop'].click(
|
||||
stop_everything_event, None, None, queue=False, cancels=gen_events if shared.args.no_stream else None).then(
|
||||
chat.redraw_html, reload_inputs, shared.gradio['display'])
|
||||
|
||||
shared.gradio['Chat mode'].change(
|
||||
lambda x: gr.update(visible=x == 'instruct'), shared.gradio['Chat mode'], shared.gradio['Instruction templates']).then(
|
||||
lambda x: gr.update(interactive=x != 'instruct'), shared.gradio['Chat mode'], shared.gradio['character_menu']).then(
|
||||
chat.redraw_html, reload_inputs, shared.gradio['display'])
|
||||
|
||||
shared.gradio['Instruction templates'].change(
|
||||
lambda character, name1, name2, mode: chat.load_character(character, name1, name2, mode), [shared.gradio[k] for k in ['Instruction templates', 'name1', 'name2', 'Chat mode']], [shared.gradio[k] for k in ['name1', 'name2', 'character_picture', 'greeting', 'context', 'end_of_turn', 'display']]).then(
|
||||
chat.redraw_html, reload_inputs, shared.gradio['display'])
|
||||
|
||||
shared.gradio['upload_chat_history'].upload(
|
||||
chat.load_history, [shared.gradio[k] for k in ['upload_chat_history', 'name1', 'name2']], None).then(
|
||||
chat.redraw_html, reload_inputs, shared.gradio['display'])
|
||||
|
||||
gen_events.append(shared.gradio['Impersonate'].click(chat.impersonate_wrapper, shared.input_params, shared.gradio['textbox'], show_progress=shared.args.no_stream))
|
||||
shared.gradio['Copy last reply'].click(chat.send_last_reply_to_input, None, shared.gradio['textbox'], show_progress=shared.args.no_stream)
|
||||
shared.gradio['Clear history'].click(lambda: [gr.update(visible=True), gr.update(visible=False), gr.update(visible=True)], None, clear_arr)
|
||||
shared.gradio['Clear history-confirm'].click(lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, clear_arr)
|
||||
shared.gradio['Clear history-confirm'].click(chat.clear_chat_log, [shared.gradio[k] for k in ['name1', 'name2', 'greeting', 'Chat mode']], shared.gradio['display'])
|
||||
shared.gradio['Clear history-cancel'].click(lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, clear_arr)
|
||||
shared.gradio['Chat mode'].change(lambda x: gr.update(visible=x == 'instruct'), shared.gradio['Chat mode'], shared.gradio['Instruction templates'])
|
||||
|
||||
shared.gradio['Remove last'].click(chat.remove_last_message, [shared.gradio[k] for k in ['name1', 'name2', 'Chat mode']], [shared.gradio['display'], shared.gradio['textbox']], show_progress=False)
|
||||
shared.gradio['download_button'].click(chat.save_history, inputs=[], outputs=[shared.gradio['download']])
|
||||
shared.gradio['download_button'].click(lambda x: chat.save_history(x, timestamp=True), shared.gradio['Chat mode'], shared.gradio['download'])
|
||||
shared.gradio['Upload character'].click(chat.upload_character, [shared.gradio['upload_json'], shared.gradio['upload_img_bot']], [shared.gradio['character_menu']])
|
||||
|
||||
# Clearing stuff and saving the history
|
||||
for i in ['Generate', 'Regenerate', 'Replace last reply']:
|
||||
shared.gradio[i].click(lambda x: '', shared.gradio['textbox'], shared.gradio['textbox'], show_progress=False)
|
||||
shared.gradio[i].click(lambda: chat.save_history(timestamp=False), [], [], show_progress=False)
|
||||
shared.gradio['Clear history-confirm'].click(lambda: chat.save_history(timestamp=False), [], [], show_progress=False)
|
||||
shared.gradio['textbox'].submit(lambda x: '', shared.gradio['textbox'], shared.gradio['textbox'], show_progress=False)
|
||||
shared.gradio['textbox'].submit(lambda: chat.save_history(timestamp=False), [], [], show_progress=False)
|
||||
|
||||
shared.gradio['character_menu'].change(chat.load_character, [shared.gradio[k] for k in ['character_menu', 'name1', 'name2', 'Chat mode']], [shared.gradio[k] for k in ['name1', 'name2', 'character_picture', 'greeting', 'context', 'end_of_turn', 'display']])
|
||||
shared.gradio['Instruction templates'].change(lambda character, name1, name2, mode: chat.load_character(character, name1, name2, mode), [shared.gradio[k] for k in ['Instruction templates', 'name1', 'name2', 'Chat mode']], [shared.gradio[k] for k in ['name1', 'name2', 'character_picture', 'greeting', 'context', 'end_of_turn', 'display']])
|
||||
shared.gradio['upload_chat_history'].upload(chat.load_history, [shared.gradio[k] for k in ['upload_chat_history', 'name1', 'name2']], [])
|
||||
shared.gradio['upload_img_tavern'].upload(chat.upload_tavern_character, [shared.gradio['upload_img_tavern'], shared.gradio['name1'], shared.gradio['name2']], [shared.gradio['character_menu']])
|
||||
shared.gradio['your_picture'].change(chat.upload_your_profile_picture, [shared.gradio[k] for k in ['your_picture', 'name1', 'name2', 'Chat mode']], shared.gradio['display'])
|
||||
|
||||
reload_inputs = [shared.gradio[k] for k in ['name1', 'name2', 'Chat mode']]
|
||||
shared.gradio['upload_chat_history'].upload(chat.redraw_html, reload_inputs, [shared.gradio['display']])
|
||||
shared.gradio['Stop'].click(chat.redraw_html, reload_inputs, [shared.gradio['display']])
|
||||
shared.gradio['Instruction templates'].change(chat.redraw_html, reload_inputs, [shared.gradio['display']])
|
||||
shared.gradio['Chat mode'].change(chat.redraw_html, reload_inputs, [shared.gradio['display']])
|
||||
|
||||
shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js+ui.chat_js}}}")
|
||||
shared.gradio['interface'].load(lambda: chat.load_default_history(shared.settings['name1'], shared.settings['name2']), None, None)
|
||||
shared.gradio['interface'].load(chat.redraw_html, reload_inputs, [shared.gradio['display']], show_progress=True)
|
||||
shared.gradio['interface'].load(chat.load_default_history, [shared.gradio[k] for k in ['name1', 'name2']], None)
|
||||
shared.gradio['interface'].load(chat.redraw_html, reload_inputs, shared.gradio['display'], show_progress=True)
|
||||
|
||||
elif shared.args.notebook:
|
||||
with gr.Tab("Text generation", elem_id="main"):
|
||||
@@ -477,7 +539,7 @@ def create_interface():
|
||||
output_params = [shared.gradio[k] for k in ['textbox', 'markdown', 'html']]
|
||||
gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream))
|
||||
gen_events.append(shared.gradio['textbox'].submit(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream))
|
||||
shared.gradio['Stop'].click(stop_everything_event, [], [], queue=False, cancels=gen_events if shared.args.no_stream else None)
|
||||
shared.gradio['Stop'].click(stop_everything_event, None, None, queue=False, cancels=gen_events if shared.args.no_stream else None)
|
||||
shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js}}}")
|
||||
|
||||
else:
|
||||
@@ -511,7 +573,7 @@ def create_interface():
|
||||
gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream))
|
||||
gen_events.append(shared.gradio['textbox'].submit(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream))
|
||||
gen_events.append(shared.gradio['Continue'].click(generate_reply, [shared.gradio['output_textbox']] + shared.input_params[1:], output_params, show_progress=shared.args.no_stream))
|
||||
shared.gradio['Stop'].click(stop_everything_event, [], [], queue=False, cancels=gen_events if shared.args.no_stream else None)
|
||||
shared.gradio['Stop'].click(stop_everything_event, None, None, queue=False, cancels=gen_events if shared.args.no_stream else None)
|
||||
shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js}}}")
|
||||
|
||||
with gr.Tab("Model", elem_id="model-tab"):
|
||||
@@ -537,8 +599,10 @@ def create_interface():
|
||||
shared.gradio['bool_menu'] = gr.CheckboxGroup(choices=bool_list, value=bool_active, label="Boolean command-line flags")
|
||||
shared.gradio['reset_interface'] = gr.Button("Apply and restart the interface")
|
||||
|
||||
shared.gradio['reset_interface'].click(set_interface_arguments, [shared.gradio[k] for k in ['interface_modes_menu', 'extensions_menu', 'bool_menu']], None)
|
||||
shared.gradio['reset_interface'].click(lambda: None, None, None, _js='() => {document.body.innerHTML=\'<h1 style="font-family:monospace;margin-top:20%;color:lightgray;text-align:center;">Reloading...</h1>\'; setTimeout(function(){location.reload()},2500); return []}')
|
||||
# Reset interface event
|
||||
shared.gradio['reset_interface'].click(
|
||||
set_interface_arguments, [shared.gradio[k] for k in ['interface_modes_menu', 'extensions_menu', 'bool_menu']], None).then(
|
||||
lambda: None, None, None, _js='() => {document.body.innerHTML=\'<h1 style="font-family:monospace;margin-top:20%;color:lightgray;text-align:center;">Reloading...</h1>\'; setTimeout(function(){location.reload()},2500); return []}')
|
||||
|
||||
if shared.args.extensions is not None:
|
||||
extensions_module.create_extensions_block()
|
||||
@@ -547,7 +611,7 @@ def create_interface():
|
||||
d[key] = value
|
||||
return d
|
||||
|
||||
for k in ['do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'max_new_tokens', 'seed', 'stop_at_newline', 'chat_prompt_size_slider', 'chat_generation_attempts']:
|
||||
for k in ['do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'add_bos_token', 'max_new_tokens', 'seed', 'stop_at_newline', 'chat_prompt_size_slider', 'chat_generation_attempts']:
|
||||
if k not in shared.gradio:
|
||||
continue
|
||||
if type(shared.gradio[k]) in [gr.Checkbox, gr.Number]:
|
||||
|
||||
@@ -7,7 +7,9 @@
|
||||
"name2": "Assistant",
|
||||
"context": "This is a conversation with your Assistant. The Assistant is very helpful and is eager to chat with you and answer your questions.",
|
||||
"greeting": "Hello there!",
|
||||
"end_of_turn": "",
|
||||
"stop_at_newline": false,
|
||||
"add_bos_token": true,
|
||||
"chat_prompt_size": 2048,
|
||||
"chat_prompt_size_min": 0,
|
||||
"chat_prompt_size_max": 2048,
|
||||
@@ -19,7 +21,8 @@
|
||||
"gallery"
|
||||
],
|
||||
"presets": {
|
||||
"default": "NovelAI-Sphinx Moth",
|
||||
"default": "Default",
|
||||
".*(alpaca|llama)": "LLaMA-Precise",
|
||||
".*pygmalion": "NovelAI-Storywriter",
|
||||
".*RWKV": "Naive"
|
||||
},
|
||||
|
||||
Reference in New Issue
Block a user