Compare commits
20 Commits
reformat2
...
concatenat
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
25652f0994 | ||
|
|
0b458bf82d | ||
|
|
ffd102e5c0 | ||
|
|
5543a5089d | ||
|
|
1dc464dcb0 | ||
|
|
962e33dc10 | ||
|
|
42ea6a3fc0 | ||
|
|
e563b015d8 | ||
|
|
1c413ed593 | ||
|
|
3f922d4bfb | ||
|
|
744bf7cbf2 | ||
|
|
768354239b | ||
|
|
6762e62a40 | ||
|
|
a453d4e9c4 | ||
|
|
ec979cd9c4 | ||
|
|
2c0018d946 | ||
|
|
8fa182cfa7 | ||
|
|
862aad637b | ||
|
|
46c4654226 | ||
|
|
ea6e77df72 |
@@ -1,7 +1,6 @@
|
||||
.env
|
||||
Dockerfile
|
||||
/characters
|
||||
/extensions
|
||||
/loras
|
||||
/models
|
||||
/presets
|
||||
|
||||
19
Dockerfile
19
Dockerfile
@@ -30,8 +30,7 @@ RUN apt-get update && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
|
||||
RUN --mount=type=cache,target=/root/.cache/pip pip3 install virtualenv
|
||||
|
||||
COPY . /app/
|
||||
RUN mkdir /app
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
@@ -41,21 +40,29 @@ RUN test -n "${WEBUI_VERSION}" && git reset --hard ${WEBUI_VERSION} || echo "Usi
|
||||
RUN virtualenv /app/venv
|
||||
RUN . /app/venv/bin/activate && \
|
||||
pip3 install --upgrade pip setuptools && \
|
||||
pip3 install torch torchvision torchaudio && \
|
||||
pip3 install -r requirements.txt
|
||||
pip3 install torch torchvision torchaudio
|
||||
|
||||
COPY --from=builder /build /app/repositories/GPTQ-for-LLaMa
|
||||
RUN . /app/venv/bin/activate && \
|
||||
pip3 install /app/repositories/GPTQ-for-LLaMa/*.whl
|
||||
|
||||
ENV CLI_ARGS=""
|
||||
|
||||
COPY extensions/api/requirements.txt /app/extensions/api/requirements.txt
|
||||
COPY extensions/elevenlabs_tts/requirements.txt /app/extensions/elevenlabs_tts/requirements.txt
|
||||
COPY extensions/google_translate/requirements.txt /app/extensions/google_translate/requirements.txt
|
||||
COPY extensions/silero_tts/requirements.txt /app/extensions/silero_tts/requirements.txt
|
||||
COPY extensions/whisper_stt/requirements.txt /app/extensions/whisper_stt/requirements.txt
|
||||
RUN --mount=type=cache,target=/root/.cache/pip . /app/venv/bin/activate && cd extensions/api && pip3 install -r requirements.txt
|
||||
RUN --mount=type=cache,target=/root/.cache/pip . /app/venv/bin/activate && cd extensions/elevenlabs_tts && pip3 install -r requirements.txt
|
||||
RUN --mount=type=cache,target=/root/.cache/pip . /app/venv/bin/activate && cd extensions/google_translate && pip3 install -r requirements.txt
|
||||
RUN --mount=type=cache,target=/root/.cache/pip . /app/venv/bin/activate && cd extensions/silero_tts && pip3 install -r requirements.txt
|
||||
RUN --mount=type=cache,target=/root/.cache/pip . /app/venv/bin/activate && cd extensions/whisper_stt && pip3 install -r requirements.txt
|
||||
|
||||
COPY requirements.txt /app/requirements.txt
|
||||
RUN . /app/venv/bin/activate && \
|
||||
pip3 install -r requirements.txt
|
||||
|
||||
RUN cp /app/venv/lib/python3.10/site-packages/bitsandbytes/libbitsandbytes_cuda118.so /app/venv/lib/python3.10/site-packages/bitsandbytes/libbitsandbytes_cpu.so
|
||||
|
||||
COPY . /app/
|
||||
ENV CLI_ARGS=""
|
||||
CMD . /app/venv/bin/activate && python3 server.py ${CLI_ARGS}
|
||||
|
||||
106
README.md
106
README.md
@@ -119,7 +119,7 @@ As an alternative to the recommended WSL method, you can install the web UI nati
|
||||
|
||||
```
|
||||
cp .env.example .env
|
||||
docker-compose up --build
|
||||
docker compose up --build
|
||||
```
|
||||
|
||||
Make sure to edit `.env.example` and set the appropriate CUDA version for your GPU.
|
||||
@@ -191,82 +191,82 @@ Optionally, you can use the following command-line flags:
|
||||
|
||||
#### Basic settings
|
||||
|
||||
| Flag | Description |
|
||||
|------------------|-------------|
|
||||
| `-h`, `--help` | show this help message and exit |
|
||||
| `--notebook` | Launch the web UI in notebook mode, where the output is written to the same text box as the input. |
|
||||
| `--chat` | Launch the web UI in chat mode.|
|
||||
| `--model MODEL` | Name of the model to load by default. |
|
||||
| `--lora LORA` | Name of the LoRA to apply to the model by default. |
|
||||
| `--model-dir MODEL_DIR` | Path to directory with all the models |
|
||||
| `--lora-dir LORA_DIR` | Path to directory with all the loras |
|
||||
| `--no-stream` | Don't stream the text output in real time. |
|
||||
| `--settings SETTINGS_FILE` | Load the default interface settings from this json file. See `settings-template.json` for an example. If you create a file called `settings.json`, this file will be loaded by default without the need to use the `--settings` flag.|
|
||||
| `--extensions EXTENSIONS [EXTENSIONS ...]` | The list of extensions to load. If you want to load more than one extension, write the names separated by spaces. |
|
||||
| `--verbose` | Print the prompts to the terminal. |
|
||||
| Flag | Description |
|
||||
|--------------------------------------------|-------------|
|
||||
| `-h`, `--help` | Show this help message and exit. |
|
||||
| `--notebook` | Launch the web UI in notebook mode, where the output is written to the same text box as the input. |
|
||||
| `--chat` | Launch the web UI in chat mode. |
|
||||
| `--model MODEL` | Name of the model to load by default. |
|
||||
| `--lora LORA` | Name of the LoRA to apply to the model by default. |
|
||||
| `--model-dir MODEL_DIR` | Path to directory with all the models. |
|
||||
| `--lora-dir LORA_DIR` | Path to directory with all the loras. |
|
||||
| `--no-stream` | Don't stream the text output in real time. |
|
||||
| `--settings SETTINGS_FILE` | Load the default interface settings from this json file. See `settings-template.json` for an example. If you create a file called `settings.json`, this file will be loaded by default without the need to use the `--settings` flag. |
|
||||
| `--extensions EXTENSIONS [EXTENSIONS ...]` | The list of extensions to load. If you want to load more than one extension, write the names separated by spaces. |
|
||||
| `--verbose` | Print the prompts to the terminal. |
|
||||
|
||||
#### Accelerate/transformers
|
||||
|
||||
| Flag | Description |
|
||||
|------------------|-------------|
|
||||
| `--cpu` | Use the CPU to generate text.|
|
||||
| `--auto-devices` | Automatically split the model across the available GPU(s) and CPU.|
|
||||
| `--gpu-memory GPU_MEMORY [GPU_MEMORY ...]` | Maxmimum GPU memory in GiB to be allocated per GPU. Example: `--gpu-memory 10` for a single GPU, `--gpu-memory 10 5` for two GPUs. You can also set values in MiB like `--gpu-memory 3500MiB`. |
|
||||
| `--cpu-memory CPU_MEMORY` | Maximum CPU memory in GiB to allocate for offloaded weights. Same as above.|
|
||||
| `--disk` | If the model is too large for your GPU(s) and CPU combined, send the remaining layers to the disk. |
|
||||
| `--disk-cache-dir DISK_CACHE_DIR` | Directory to save the disk cache to. Defaults to `cache/`. |
|
||||
| `--load-in-8bit` | Load the model with 8-bit precision.|
|
||||
| `--bf16` | Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU. |
|
||||
| `--no-cache` | Set `use_cache` to False while generating text. This reduces the VRAM usage a bit with a performance cost. |
|
||||
| Flag | Description |
|
||||
|---------------------------------------------|-------------|
|
||||
| `--cpu` | Use the CPU to generate text. |
|
||||
| `--auto-devices` | Automatically split the model across the available GPU(s) and CPU. |
|
||||
| `--gpu-memory GPU_MEMORY [GPU_MEMORY ...]` | Maxmimum GPU memory in GiB to be allocated per GPU. Example: `--gpu-memory 10` for a single GPU, `--gpu-memory 10 5` for two GPUs. You can also set values in MiB like `--gpu-memory 3500MiB`. |
|
||||
| `--cpu-memory CPU_MEMORY` | Maximum CPU memory in GiB to allocate for offloaded weights. Same as above.|
|
||||
| `--disk` | If the model is too large for your GPU(s) and CPU combined, send the remaining layers to the disk. |
|
||||
| `--disk-cache-dir DISK_CACHE_DIR` | Directory to save the disk cache to. Defaults to `cache/`. |
|
||||
| `--load-in-8bit` | Load the model with 8-bit precision.|
|
||||
| `--bf16` | Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU. |
|
||||
| `--no-cache` | Set `use_cache` to False while generating text. This reduces the VRAM usage a bit with a performance cost. |
|
||||
|
||||
#### llama.cpp
|
||||
|
||||
| Flag | Description |
|
||||
|------------------|-------------|
|
||||
| `--threads` | Number of threads to use in llama.cpp. |
|
||||
| Flag | Description |
|
||||
|-------------|-------------|
|
||||
| `--threads` | Number of threads to use in llama.cpp. |
|
||||
|
||||
#### GPTQ
|
||||
|
||||
| Flag | Description |
|
||||
|------------------|-------------|
|
||||
| `--wbits WBITS` | GPTQ: Load a pre-quantized model with specified precision in bits. 2, 3, 4 and 8 are supported. |
|
||||
| `--model_type MODEL_TYPE` | GPTQ: Model type of pre-quantized model. Currently LLaMA, OPT, and GPT-J are supported. |
|
||||
| `--groupsize GROUPSIZE` | GPTQ: Group size. |
|
||||
| `--pre_layer PRE_LAYER` | GPTQ: The number of layers to allocate to the GPU. Setting this parameter enables CPU offloading for 4-bit models. |
|
||||
| Flag | Description |
|
||||
|---------------------------|-------------|
|
||||
| `--wbits WBITS` | GPTQ: Load a pre-quantized model with specified precision in bits. 2, 3, 4 and 8 are supported. |
|
||||
| `--model_type MODEL_TYPE` | GPTQ: Model type of pre-quantized model. Currently LLaMA, OPT, and GPT-J are supported. |
|
||||
| `--groupsize GROUPSIZE` | GPTQ: Group size. |
|
||||
| `--pre_layer PRE_LAYER` | GPTQ: The number of layers to allocate to the GPU. Setting this parameter enables CPU offloading for 4-bit models. |
|
||||
|
||||
#### FlexGen
|
||||
|
||||
| Flag | Description |
|
||||
|------------------|-------------|
|
||||
| `--flexgen` | Enable the use of FlexGen offloading. |
|
||||
| `--percent PERCENT [PERCENT ...]` | FlexGen: allocation percentages. Must be 6 numbers separated by spaces (default: 0, 100, 100, 0, 100, 0). |
|
||||
| `--compress-weight` | FlexGen: Whether to compress weight (default: False).|
|
||||
| `--pin-weight [PIN_WEIGHT]` | FlexGen: whether to pin weights (setting this to False reduces CPU memory by 20%). |
|
||||
| `--flexgen` | Enable the use of FlexGen offloading. |
|
||||
| `--percent PERCENT [PERCENT ...]` | FlexGen: allocation percentages. Must be 6 numbers separated by spaces (default: 0, 100, 100, 0, 100, 0). |
|
||||
| `--compress-weight` | FlexGen: Whether to compress weight (default: False).|
|
||||
| `--pin-weight [PIN_WEIGHT]` | FlexGen: whether to pin weights (setting this to False reduces CPU memory by 20%). |
|
||||
|
||||
#### DeepSpeed
|
||||
|
||||
| Flag | Description |
|
||||
|------------------|-------------|
|
||||
| `--deepspeed` | Enable the use of DeepSpeed ZeRO-3 for inference via the Transformers integration. |
|
||||
| Flag | Description |
|
||||
|---------------------------------------|-------------|
|
||||
| `--deepspeed` | Enable the use of DeepSpeed ZeRO-3 for inference via the Transformers integration. |
|
||||
| `--nvme-offload-dir NVME_OFFLOAD_DIR` | DeepSpeed: Directory to use for ZeRO-3 NVME offloading. |
|
||||
| `--local_rank LOCAL_RANK` | DeepSpeed: Optional argument for distributed setups. |
|
||||
| `--local_rank LOCAL_RANK` | DeepSpeed: Optional argument for distributed setups. |
|
||||
|
||||
#### RWKV
|
||||
|
||||
| Flag | Description |
|
||||
|------------------|-------------|
|
||||
| `--rwkv-strategy RWKV_STRATEGY` | RWKV: The strategy to use while loading the model. Examples: "cpu fp32", "cuda fp16", "cuda fp16i8". |
|
||||
| `--rwkv-cuda-on` | RWKV: Compile the CUDA kernel for better performance. |
|
||||
| Flag | Description |
|
||||
|---------------------------------|-------------|
|
||||
| `--rwkv-strategy RWKV_STRATEGY` | RWKV: The strategy to use while loading the model. Examples: "cpu fp32", "cuda fp16", "cuda fp16i8". |
|
||||
| `--rwkv-cuda-on` | RWKV: Compile the CUDA kernel for better performance. |
|
||||
|
||||
#### Gradio
|
||||
|
||||
| Flag | Description |
|
||||
|------------------|-------------|
|
||||
| `--listen` | Make the web UI reachable from your local network. |
|
||||
| `--listen-port LISTEN_PORT` | The listening port that the server will use. |
|
||||
| `--share` | Create a public URL. This is useful for running the web UI on Google Colab or similar. |
|
||||
| `--auto-launch` | Open the web UI in the default browser upon launch. |
|
||||
| `--gradio-auth-path GRADIO_AUTH_PATH` | Set the gradio authentication file path. The file should contain one or more user:password pairs in this format: "u1:p1,u2:p2,u3:p3" |
|
||||
| Flag | Description |
|
||||
|---------------------------------------|-------------|
|
||||
| `--listen` | Make the web UI reachable from your local network. |
|
||||
| `--listen-port LISTEN_PORT` | The listening port that the server will use. |
|
||||
| `--share` | Create a public URL. This is useful for running the web UI on Google Colab or similar. |
|
||||
| `--auto-launch` | Open the web UI in the default browser upon launch. |
|
||||
| `--gradio-auth-path GRADIO_AUTH_PATH` | Set the gradio authentication file path. The file should contain one or more user:password pairs in this format: "u1:p1,u2:p2,u3:p3" |
|
||||
|
||||
Out of memory errors? [Check the low VRAM guide](https://github.com/oobabooga/text-generation-webui/wiki/Low-VRAM-guide).
|
||||
|
||||
|
||||
@@ -17,6 +17,7 @@ def random_hash():
|
||||
letters = string.ascii_lowercase + string.digits
|
||||
return ''.join(random.choice(letters) for i in range(9))
|
||||
|
||||
|
||||
async def run(context):
|
||||
server = "127.0.0.1"
|
||||
params = {
|
||||
@@ -41,7 +42,7 @@ async def run(context):
|
||||
|
||||
async with websockets.connect(f"ws://{server}:7860/queue/join") as websocket:
|
||||
while content := json.loads(await websocket.recv()):
|
||||
#Python3.10 syntax, replace with if elif on older
|
||||
# Python3.10 syntax, replace with if elif on older
|
||||
match content["msg"]:
|
||||
case "send_hash":
|
||||
await websocket.send(json.dumps({
|
||||
@@ -62,13 +63,14 @@ async def run(context):
|
||||
pass
|
||||
case "process_generating" | "process_completed":
|
||||
yield content["output"]["data"][0]
|
||||
# You can search for your desired end indicator and
|
||||
# You can search for your desired end indicator and
|
||||
# stop generation by closing the websocket here
|
||||
if (content["msg"] == "process_completed"):
|
||||
break
|
||||
|
||||
prompt = "What I would like to say is the following: "
|
||||
|
||||
|
||||
async def get_result():
|
||||
async for response in run(prompt):
|
||||
# Print intermediate steps
|
||||
|
||||
@@ -13,10 +13,11 @@ import torch
|
||||
from tqdm import tqdm
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
parser = argparse.ArgumentParser(formatter_class=lambda prog: argparse.HelpFormatter(prog,max_help_position=54))
|
||||
parser = argparse.ArgumentParser(formatter_class=lambda prog: argparse.HelpFormatter(prog, max_help_position=54))
|
||||
parser.add_argument('MODEL', type=str, default=None, nargs='?', help="Path to the input model.")
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
def disable_torch_init():
|
||||
"""
|
||||
Disable the redundant torch default initialization to accelerate model creation.
|
||||
@@ -31,20 +32,22 @@ def disable_torch_init():
|
||||
torch_layer_norm_init_backup = torch.nn.LayerNorm.reset_parameters
|
||||
setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
|
||||
|
||||
|
||||
def restore_torch_init():
|
||||
"""Rollback the change made by disable_torch_init."""
|
||||
import torch
|
||||
setattr(torch.nn.Linear, "reset_parameters", torch_linear_init_backup)
|
||||
setattr(torch.nn.LayerNorm, "reset_parameters", torch_layer_norm_init_backup)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
path = Path(args.MODEL)
|
||||
model_name = path.name
|
||||
|
||||
print(f"Loading {model_name}...")
|
||||
#disable_torch_init()
|
||||
# disable_torch_init()
|
||||
model = AutoModelForCausalLM.from_pretrained(path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
|
||||
#restore_torch_init()
|
||||
# restore_torch_init()
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(path)
|
||||
|
||||
|
||||
@@ -17,7 +17,7 @@ from pathlib import Path
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
parser = argparse.ArgumentParser(formatter_class=lambda prog: argparse.HelpFormatter(prog,max_help_position=54))
|
||||
parser = argparse.ArgumentParser(formatter_class=lambda prog: argparse.HelpFormatter(prog, max_help_position=54))
|
||||
parser.add_argument('MODEL', type=str, default=None, nargs='?', help="Path to the input model.")
|
||||
parser.add_argument('--output', type=str, default=None, help='Path to the output folder (default: models/{model_name}_safetensors).')
|
||||
parser.add_argument("--max-shard-size", type=str, default="2GB", help="Maximum size of a shard in GB or MB (default: %(default)s).")
|
||||
|
||||
@@ -36,3 +36,8 @@ div.svelte-362y77>*, div.svelte-362y77>.form>* {
|
||||
.wrap.svelte-6roggh.svelte-6roggh {
|
||||
max-height: 92.5%;
|
||||
}
|
||||
|
||||
/* This is for the microphone button in the whisper extension */
|
||||
.sm.svelte-1ipelgc {
|
||||
width: 100%;
|
||||
}
|
||||
|
||||
@@ -29,6 +29,7 @@ parser.add_argument('--clean', action='store_true', help='Does not resume the pr
|
||||
parser.add_argument('--check', action='store_true', help='Validates the checksums of model files.')
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
def get_file(url, output_folder):
|
||||
filename = Path(url.rsplit('/', 1)[1])
|
||||
output_path = output_folder / filename
|
||||
@@ -54,6 +55,7 @@ def get_file(url, output_folder):
|
||||
t.update(len(data))
|
||||
f.write(data)
|
||||
|
||||
|
||||
def sanitize_branch_name(branch_name):
|
||||
pattern = re.compile(r"^[a-zA-Z0-9._-]+$")
|
||||
if pattern.match(branch_name):
|
||||
@@ -61,6 +63,7 @@ def sanitize_branch_name(branch_name):
|
||||
else:
|
||||
raise ValueError("Invalid branch name. Only alphanumeric characters, period, underscore and dash are allowed.")
|
||||
|
||||
|
||||
def select_model_from_default_options():
|
||||
models = {
|
||||
"OPT 6.7B": ("facebook", "opt-6.7b", "main"),
|
||||
@@ -78,11 +81,11 @@ def select_model_from_default_options():
|
||||
choices = {}
|
||||
|
||||
print("Select the model that you want to download:\n")
|
||||
for i,name in enumerate(models):
|
||||
char = chr(ord('A')+i)
|
||||
for i, name in enumerate(models):
|
||||
char = chr(ord('A') + i)
|
||||
choices[char] = name
|
||||
print(f"{char}) {name}")
|
||||
char = chr(ord('A')+len(models))
|
||||
char = chr(ord('A') + len(models))
|
||||
print(f"{char}) None of the above")
|
||||
|
||||
print()
|
||||
@@ -106,6 +109,7 @@ EleutherAI/pythia-1.4b-deduped
|
||||
|
||||
return model, branch
|
||||
|
||||
|
||||
def get_download_links_from_huggingface(model, branch):
|
||||
base = "https://huggingface.co"
|
||||
page = f"/api/models/{model}/tree/{branch}?cursor="
|
||||
@@ -166,15 +170,17 @@ def get_download_links_from_huggingface(model, branch):
|
||||
|
||||
# If both pytorch and safetensors are available, download safetensors only
|
||||
if (has_pytorch or has_pt) and has_safetensors:
|
||||
for i in range(len(classifications)-1, -1, -1):
|
||||
for i in range(len(classifications) - 1, -1, -1):
|
||||
if classifications[i] in ['pytorch', 'pt']:
|
||||
links.pop(i)
|
||||
|
||||
return links, sha256, is_lora
|
||||
|
||||
|
||||
def download_files(file_list, output_folder, num_threads=8):
|
||||
thread_map(lambda url: get_file(url, output_folder), file_list, max_workers=num_threads, disable=True)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
model = args.MODEL
|
||||
branch = args.branch
|
||||
@@ -224,7 +230,7 @@ if __name__ == '__main__':
|
||||
validated = False
|
||||
else:
|
||||
print(f'Checksum validated: {sha256[i][0]} {sha256[i][1]}')
|
||||
|
||||
|
||||
if validated:
|
||||
print('[+] Validated checksums of all model files!')
|
||||
else:
|
||||
|
||||
@@ -9,6 +9,7 @@ params = {
|
||||
'port': 5000,
|
||||
}
|
||||
|
||||
|
||||
class Handler(BaseHTTPRequestHandler):
|
||||
def do_GET(self):
|
||||
if self.path == '/api/v1/model':
|
||||
@@ -32,7 +33,7 @@ class Handler(BaseHTTPRequestHandler):
|
||||
self.end_headers()
|
||||
|
||||
prompt = body['prompt']
|
||||
prompt_lines = [l.strip() for l in prompt.split('\n')]
|
||||
prompt_lines = [k.strip() for k in prompt.split('\n')]
|
||||
|
||||
max_context = body.get('max_context_length', 2048)
|
||||
|
||||
@@ -40,18 +41,18 @@ class Handler(BaseHTTPRequestHandler):
|
||||
prompt_lines.pop(0)
|
||||
|
||||
prompt = '\n'.join(prompt_lines)
|
||||
generate_params = {
|
||||
'max_new_tokens': int(body.get('max_length', 200)),
|
||||
generate_params = {
|
||||
'max_new_tokens': int(body.get('max_length', 200)),
|
||||
'do_sample': bool(body.get('do_sample', True)),
|
||||
'temperature': float(body.get('temperature', 0.5)),
|
||||
'top_p': float(body.get('top_p', 1)),
|
||||
'typical_p': float(body.get('typical', 1)),
|
||||
'repetition_penalty': float(body.get('rep_pen', 1.1)),
|
||||
'temperature': float(body.get('temperature', 0.5)),
|
||||
'top_p': float(body.get('top_p', 1)),
|
||||
'typical_p': float(body.get('typical', 1)),
|
||||
'repetition_penalty': float(body.get('rep_pen', 1.1)),
|
||||
'encoder_repetition_penalty': 1,
|
||||
'top_k': int(body.get('top_k', 0)),
|
||||
'top_k': int(body.get('top_k', 0)),
|
||||
'min_length': int(body.get('min_length', 0)),
|
||||
'no_repeat_ngram_size': int(body.get('no_repeat_ngram_size',0)),
|
||||
'num_beams': int(body.get('num_beams',1)),
|
||||
'no_repeat_ngram_size': int(body.get('no_repeat_ngram_size', 0)),
|
||||
'num_beams': int(body.get('num_beams', 1)),
|
||||
'penalty_alpha': float(body.get('penalty_alpha', 0)),
|
||||
'length_penalty': float(body.get('length_penalty', 1)),
|
||||
'early_stopping': bool(body.get('early_stopping', False)),
|
||||
@@ -59,7 +60,7 @@ class Handler(BaseHTTPRequestHandler):
|
||||
}
|
||||
|
||||
generator = generate_reply(
|
||||
prompt,
|
||||
prompt,
|
||||
generate_params,
|
||||
stopping_strings=body.get('stopping_strings', []),
|
||||
)
|
||||
@@ -84,9 +85,9 @@ class Handler(BaseHTTPRequestHandler):
|
||||
def run_server():
|
||||
server_addr = ('0.0.0.0' if shared.args.listen else '127.0.0.1', params['port'])
|
||||
server = ThreadingHTTPServer(server_addr, Handler)
|
||||
if shared.args.share:
|
||||
if shared.args.share:
|
||||
try:
|
||||
from flask_cloudflared import _run_cloudflared
|
||||
from flask_cloudflared import _run_cloudflared
|
||||
public_url = _run_cloudflared(params['port'], params['port'] + 1)
|
||||
print(f'Starting KoboldAI compatible api at {public_url}/api')
|
||||
except ImportError:
|
||||
@@ -95,5 +96,6 @@ def run_server():
|
||||
print(f'Starting KoboldAI compatible api at http://{server_addr[0]}:{server_addr[1]}/api')
|
||||
server.serve_forever()
|
||||
|
||||
|
||||
def setup():
|
||||
Thread(target=run_server, daemon=True).start()
|
||||
|
||||
@@ -5,14 +5,16 @@ params = {
|
||||
"bias string": " *I am so happy*",
|
||||
}
|
||||
|
||||
|
||||
def input_modifier(string):
|
||||
"""
|
||||
This function is applied to your text inputs before
|
||||
they are fed into the model.
|
||||
"""
|
||||
"""
|
||||
|
||||
return string
|
||||
|
||||
|
||||
def output_modifier(string):
|
||||
"""
|
||||
This function is applied to the model outputs.
|
||||
@@ -20,6 +22,7 @@ def output_modifier(string):
|
||||
|
||||
return string
|
||||
|
||||
|
||||
def bot_prefix_modifier(string):
|
||||
"""
|
||||
This function is only applied in chat mode. It modifies
|
||||
@@ -27,11 +30,12 @@ def bot_prefix_modifier(string):
|
||||
behavior.
|
||||
"""
|
||||
|
||||
if params['activate'] == True:
|
||||
if params['activate']:
|
||||
return f'{string} {params["bias string"].strip()} '
|
||||
else:
|
||||
return string
|
||||
|
||||
|
||||
def ui():
|
||||
# Gradio elements
|
||||
activate = gr.Checkbox(value=params['activate'], label='Activate character bias')
|
||||
|
||||
@@ -2,10 +2,11 @@ import re
|
||||
from pathlib import Path
|
||||
|
||||
import gradio as gr
|
||||
import modules.shared as shared
|
||||
from elevenlabslib import ElevenLabsUser
|
||||
from elevenlabslib.helpers import save_bytes_to_path
|
||||
|
||||
import modules.shared as shared
|
||||
|
||||
params = {
|
||||
'activate': True,
|
||||
'api_key': '12345',
|
||||
@@ -20,16 +21,18 @@ user_info = None
|
||||
if not shared.args.no_stream:
|
||||
print("Please add --no-stream. This extension is not meant to be used with streaming.")
|
||||
raise ValueError
|
||||
|
||||
|
||||
# Check if the API is valid and refresh the UI accordingly.
|
||||
|
||||
|
||||
def check_valid_api():
|
||||
|
||||
|
||||
global user, user_info, params
|
||||
|
||||
user = ElevenLabsUser(params['api_key'])
|
||||
user_info = user._get_subscription_data()
|
||||
print('checking api')
|
||||
if params['activate'] == False:
|
||||
if not params['activate']:
|
||||
return gr.update(value='Disconnected')
|
||||
elif user_info is None:
|
||||
print('Incorrect API Key')
|
||||
@@ -37,24 +40,28 @@ def check_valid_api():
|
||||
else:
|
||||
print('Got an API Key!')
|
||||
return gr.update(value='Connected')
|
||||
|
||||
|
||||
# Once the API is verified, get the available voices and update the dropdown list
|
||||
|
||||
|
||||
def refresh_voices():
|
||||
|
||||
|
||||
global user, user_info
|
||||
|
||||
|
||||
your_voices = [None]
|
||||
if user_info is not None:
|
||||
for voice in user.get_available_voices():
|
||||
your_voices.append(voice.initialName)
|
||||
return gr.Dropdown.update(choices=your_voices)
|
||||
return gr.Dropdown.update(choices=your_voices)
|
||||
else:
|
||||
return
|
||||
|
||||
|
||||
def remove_surrounded_chars(string):
|
||||
# this expression matches to 'as few symbols as possible (0 upwards) between any asterisks' OR
|
||||
# 'as few symbols as possible (0 upwards) between an asterisk and the end of the string'
|
||||
return re.sub('\*[^\*]*?(\*|$)','',string)
|
||||
return re.sub('\*[^\*]*?(\*|$)', '', string)
|
||||
|
||||
|
||||
def input_modifier(string):
|
||||
"""
|
||||
@@ -64,16 +71,17 @@ def input_modifier(string):
|
||||
|
||||
return string
|
||||
|
||||
|
||||
def output_modifier(string):
|
||||
"""
|
||||
This function is applied to the model outputs.
|
||||
"""
|
||||
|
||||
global params, wav_idx, user, user_info
|
||||
|
||||
if params['activate'] == False:
|
||||
|
||||
if not params['activate']:
|
||||
return string
|
||||
elif user_info == None:
|
||||
elif user_info is None:
|
||||
return string
|
||||
|
||||
string = remove_surrounded_chars(string)
|
||||
@@ -84,7 +92,7 @@ def output_modifier(string):
|
||||
|
||||
if string == '':
|
||||
string = 'empty reply, try regenerating'
|
||||
|
||||
|
||||
output_file = Path(f'extensions/elevenlabs_tts/outputs/{wav_idx:06d}.wav'.format(wav_idx))
|
||||
voice = user.get_voices_by_name(params['selected_voice'])[0]
|
||||
audio_data = voice.generate_audio_bytes(string)
|
||||
@@ -94,6 +102,7 @@ def output_modifier(string):
|
||||
wav_idx += 1
|
||||
return string
|
||||
|
||||
|
||||
def ui():
|
||||
|
||||
# Gradio elements
|
||||
@@ -110,4 +119,4 @@ def ui():
|
||||
voice.change(lambda x: params.update({'selected_voice': x}), voice, None)
|
||||
api_key.change(lambda x: params.update({'api_key': x}), api_key, None)
|
||||
connect.click(check_valid_api, [], connection_status)
|
||||
connect.click(refresh_voices, [], voice)
|
||||
connect.click(refresh_voices, [], voice)
|
||||
|
||||
@@ -85,12 +85,12 @@ def select_character(evt: gr.SelectData):
|
||||
def ui():
|
||||
with gr.Accordion("Character gallery", open=False):
|
||||
update = gr.Button("Refresh")
|
||||
gr.HTML(value="<style>"+generate_css()+"</style>")
|
||||
gr.HTML(value="<style>" + generate_css() + "</style>")
|
||||
gallery = gr.Dataset(components=[gr.HTML(visible=False)],
|
||||
label="",
|
||||
samples=generate_html(),
|
||||
elem_classes=["character-gallery"],
|
||||
samples_per_page=50
|
||||
)
|
||||
label="",
|
||||
samples=generate_html(),
|
||||
elem_classes=["character-gallery"],
|
||||
samples_per_page=50
|
||||
)
|
||||
update.click(generate_html, [], gallery)
|
||||
gallery.select(select_character, None, gradio['character_menu'])
|
||||
gallery.select(select_character, None, gradio['character_menu'])
|
||||
|
||||
@@ -7,14 +7,16 @@ params = {
|
||||
|
||||
language_codes = {'Afrikaans': 'af', 'Albanian': 'sq', 'Amharic': 'am', 'Arabic': 'ar', 'Armenian': 'hy', 'Azerbaijani': 'az', 'Basque': 'eu', 'Belarusian': 'be', 'Bengali': 'bn', 'Bosnian': 'bs', 'Bulgarian': 'bg', 'Catalan': 'ca', 'Cebuano': 'ceb', 'Chinese (Simplified)': 'zh-CN', 'Chinese (Traditional)': 'zh-TW', 'Corsican': 'co', 'Croatian': 'hr', 'Czech': 'cs', 'Danish': 'da', 'Dutch': 'nl', 'English': 'en', 'Esperanto': 'eo', 'Estonian': 'et', 'Finnish': 'fi', 'French': 'fr', 'Frisian': 'fy', 'Galician': 'gl', 'Georgian': 'ka', 'German': 'de', 'Greek': 'el', 'Gujarati': 'gu', 'Haitian Creole': 'ht', 'Hausa': 'ha', 'Hawaiian': 'haw', 'Hebrew': 'iw', 'Hindi': 'hi', 'Hmong': 'hmn', 'Hungarian': 'hu', 'Icelandic': 'is', 'Igbo': 'ig', 'Indonesian': 'id', 'Irish': 'ga', 'Italian': 'it', 'Japanese': 'ja', 'Javanese': 'jw', 'Kannada': 'kn', 'Kazakh': 'kk', 'Khmer': 'km', 'Korean': 'ko', 'Kurdish': 'ku', 'Kyrgyz': 'ky', 'Lao': 'lo', 'Latin': 'la', 'Latvian': 'lv', 'Lithuanian': 'lt', 'Luxembourgish': 'lb', 'Macedonian': 'mk', 'Malagasy': 'mg', 'Malay': 'ms', 'Malayalam': 'ml', 'Maltese': 'mt', 'Maori': 'mi', 'Marathi': 'mr', 'Mongolian': 'mn', 'Myanmar (Burmese)': 'my', 'Nepali': 'ne', 'Norwegian': 'no', 'Nyanja (Chichewa)': 'ny', 'Pashto': 'ps', 'Persian': 'fa', 'Polish': 'pl', 'Portuguese (Portugal, Brazil)': 'pt', 'Punjabi': 'pa', 'Romanian': 'ro', 'Russian': 'ru', 'Samoan': 'sm', 'Scots Gaelic': 'gd', 'Serbian': 'sr', 'Sesotho': 'st', 'Shona': 'sn', 'Sindhi': 'sd', 'Sinhala (Sinhalese)': 'si', 'Slovak': 'sk', 'Slovenian': 'sl', 'Somali': 'so', 'Spanish': 'es', 'Sundanese': 'su', 'Swahili': 'sw', 'Swedish': 'sv', 'Tagalog (Filipino)': 'tl', 'Tajik': 'tg', 'Tamil': 'ta', 'Telugu': 'te', 'Thai': 'th', 'Turkish': 'tr', 'Ukrainian': 'uk', 'Urdu': 'ur', 'Uzbek': 'uz', 'Vietnamese': 'vi', 'Welsh': 'cy', 'Xhosa': 'xh', 'Yiddish': 'yi', 'Yoruba': 'yo', 'Zulu': 'zu'}
|
||||
|
||||
|
||||
def input_modifier(string):
|
||||
"""
|
||||
This function is applied to your text inputs before
|
||||
they are fed into the model.
|
||||
"""
|
||||
"""
|
||||
|
||||
return GoogleTranslator(source=params['language string'], target='en').translate(string)
|
||||
|
||||
|
||||
def output_modifier(string):
|
||||
"""
|
||||
This function is applied to the model outputs.
|
||||
@@ -22,6 +24,7 @@ def output_modifier(string):
|
||||
|
||||
return GoogleTranslator(source='en', target=params['language string']).translate(string)
|
||||
|
||||
|
||||
def bot_prefix_modifier(string):
|
||||
"""
|
||||
This function is only applied in chat mode. It modifies
|
||||
@@ -31,6 +34,7 @@ def bot_prefix_modifier(string):
|
||||
|
||||
return string
|
||||
|
||||
|
||||
def ui():
|
||||
# Finding the language name from the language code to use as the default value
|
||||
language_name = list(language_codes.keys())[list(language_codes.values()).index(params['language string'])]
|
||||
|
||||
@@ -1,15 +1,18 @@
|
||||
import gradio as gr
|
||||
import modules.shared as shared
|
||||
import pandas as pd
|
||||
|
||||
import modules.shared as shared
|
||||
|
||||
df = pd.read_csv("https://raw.githubusercontent.com/devbrones/llama-prompts/main/prompts/prompts.csv")
|
||||
|
||||
|
||||
def get_prompt_by_name(name):
|
||||
if name == 'None':
|
||||
return ''
|
||||
else:
|
||||
return df[df['Prompt name'] == name].iloc[0]['Prompt'].replace('\\n', '\n')
|
||||
|
||||
|
||||
def ui():
|
||||
if not shared.is_chat():
|
||||
choices = ['None'] + list(df['Prompt name'])
|
||||
|
||||
78
extensions/sd_api_pictures/README.MD
Normal file
78
extensions/sd_api_pictures/README.MD
Normal file
@@ -0,0 +1,78 @@
|
||||
## Description:
|
||||
TL;DR: Lets the bot answer you with a picture!
|
||||
|
||||
Stable Diffusion API pictures for TextGen, v.1.1.0
|
||||
An extension to [oobabooga's textgen-webui](https://github.com/oobabooga/text-generation-webui) allowing you to receive pics generated by [Automatic1111's SD-WebUI API](https://github.com/AUTOMATIC1111/stable-diffusion-webui)
|
||||
|
||||
<details>
|
||||
<summary>Interface overview</summary>
|
||||
|
||||

|
||||
|
||||
</details>
|
||||
|
||||
Load it in the `--chat` mode with `--extension sd_api_pictures` alongside `send_pictures` (it's not really required, but completes the picture, *pun intended*).
|
||||
|
||||
The image generation is triggered either:
|
||||
- manually through the 'Force the picture response' button while in `Manual` or `Immersive/Interactive` modes OR
|
||||
- automatically in `Immersive/Interactive` mode if the words `'send|main|message|me'` are followed by `'image|pic|picture|photo|snap|snapshot|selfie|meme'` in the user's prompt
|
||||
- always on in Picturebook/Adventure mode (if not currently suppressed by 'Suppress the picture response')
|
||||
|
||||
## Prerequisites
|
||||
|
||||
One needs an available instance of Automatic1111's webui running with an `--api` flag. Ain't tested with a notebook / cloud hosted one but should be possible.
|
||||
To run it locally in parallel on the same machine, specify custom `--listen-port` for either Auto1111's or ooba's webUIs.
|
||||
|
||||
## Features:
|
||||
- API detection (press enter in the API box)
|
||||
- VRAM management (model shuffling)
|
||||
- Three different operation modes (manual, interactive, always-on)
|
||||
- persistent settings via settings.json
|
||||
|
||||
The model input is modified only in the interactive mode; other two are unaffected. The output pic description is presented differently for Picture-book / Adventure mode.
|
||||
|
||||
Connection check (insert the Auto1111's address and press Enter):
|
||||

|
||||
|
||||
### Persistents settings
|
||||
|
||||
Create or modify the `settings.json` in the `text-generation-webui` root directory to override the defaults
|
||||
present in script.py, ex:
|
||||
|
||||
```json
|
||||
{
|
||||
"sd_api_pictures-manage_VRAM": 1,
|
||||
"sd_api_pictures-save_img": 1,
|
||||
"sd_api_pictures-prompt_prefix": "(Masterpiece:1.1), detailed, intricate, colorful, (solo:1.1)",
|
||||
"sd_api_pictures-sampler_name": "DPM++ 2M Karras"
|
||||
}
|
||||
```
|
||||
|
||||
will automatically set the `Manage VRAM` & `Keep original images` checkboxes and change the texts in `Prompt Prefix` and `Sampler name` on load.
|
||||
|
||||
---
|
||||
|
||||
## Demonstrations:
|
||||
|
||||
Those are examples of the version 1.0.0, but the core functionality is still the same
|
||||
|
||||
<details>
|
||||
<summary>Conversation 1</summary>
|
||||
|
||||

|
||||

|
||||

|
||||

|
||||
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>Conversation 2</summary>
|
||||
|
||||

|
||||

|
||||

|
||||
|
||||
</details>
|
||||
|
||||
@@ -1,102 +1,163 @@
|
||||
import base64
|
||||
import io
|
||||
import re
|
||||
import time
|
||||
from datetime import date
|
||||
from pathlib import Path
|
||||
|
||||
import gradio as gr
|
||||
import modules.chat as chat
|
||||
import modules.shared as shared
|
||||
import requests
|
||||
import torch
|
||||
from modules.models import reload_model, unload_model
|
||||
from PIL import Image
|
||||
|
||||
torch._C._jit_set_profiling_mode(False)
|
||||
|
||||
# parameters which can be customized in settings.json of webui
|
||||
# parameters which can be customized in settings.json of webui
|
||||
params = {
|
||||
'enable_SD_api': False,
|
||||
'address': 'http://127.0.0.1:7860',
|
||||
'mode': 0, # modes of operation: 0 (Manual only), 1 (Immersive/Interactive - looks for words to trigger), 2 (Picturebook Adventure - Always on)
|
||||
'manage_VRAM': False,
|
||||
'save_img': False,
|
||||
'SD_model': 'NeverEndingDream', # not really used right now
|
||||
'prompt_prefix': '(Masterpiece:1.1), (solo:1.3), detailed, intricate, colorful',
|
||||
'SD_model': 'NeverEndingDream', # not used right now
|
||||
'prompt_prefix': '(Masterpiece:1.1), detailed, intricate, colorful',
|
||||
'negative_prompt': '(worst quality, low quality:1.3)',
|
||||
'side_length': 512,
|
||||
'restore_faces': False
|
||||
'width': 512,
|
||||
'height': 512,
|
||||
'restore_faces': False,
|
||||
'seed': -1,
|
||||
'sampler_name': 'DDIM',
|
||||
'steps': 32,
|
||||
'cfg_scale': 7
|
||||
}
|
||||
|
||||
SD_models = ['NeverEndingDream'] # TODO: get with http://{address}}/sdapi/v1/sd-models and allow user to select
|
||||
|
||||
streaming_state = shared.args.no_stream # remember if chat streaming was enabled
|
||||
picture_response = False # specifies if the next model response should appear as a picture
|
||||
pic_id = 0
|
||||
def give_VRAM_priority(actor):
|
||||
global shared, params
|
||||
|
||||
if actor == 'SD':
|
||||
unload_model()
|
||||
print("Requesting Auto1111 to re-load last checkpoint used...")
|
||||
response = requests.post(url=f'{params["address"]}/sdapi/v1/reload-checkpoint', json='')
|
||||
response.raise_for_status()
|
||||
|
||||
elif actor == 'LLM':
|
||||
print("Requesting Auto1111 to vacate VRAM...")
|
||||
response = requests.post(url=f'{params["address"]}/sdapi/v1/unload-checkpoint', json='')
|
||||
response.raise_for_status()
|
||||
reload_model()
|
||||
|
||||
elif actor == 'set':
|
||||
print("VRAM mangement activated -- requesting Auto1111 to vacate VRAM...")
|
||||
response = requests.post(url=f'{params["address"]}/sdapi/v1/unload-checkpoint', json='')
|
||||
response.raise_for_status()
|
||||
|
||||
elif actor == 'reset':
|
||||
print("VRAM mangement deactivated -- requesting Auto1111 to reload checkpoint")
|
||||
response = requests.post(url=f'{params["address"]}/sdapi/v1/reload-checkpoint', json='')
|
||||
response.raise_for_status()
|
||||
|
||||
else:
|
||||
raise RuntimeError(f'Managing VRAM: "{actor}" is not a known state!')
|
||||
|
||||
response.raise_for_status()
|
||||
del response
|
||||
|
||||
|
||||
if params['manage_VRAM']:
|
||||
give_VRAM_priority('set')
|
||||
|
||||
samplers = ['DDIM', 'DPM++ 2M Karras'] # TODO: get the availible samplers with http://{address}}/sdapi/v1/samplers
|
||||
SD_models = ['NeverEndingDream'] # TODO: get with http://{address}}/sdapi/v1/sd-models and allow user to select
|
||||
|
||||
streaming_state = shared.args.no_stream # remember if chat streaming was enabled
|
||||
picture_response = False # specifies if the next model response should appear as a picture
|
||||
|
||||
|
||||
def remove_surrounded_chars(string):
|
||||
# this expression matches to 'as few symbols as possible (0 upwards) between any asterisks' OR
|
||||
# 'as few symbols as possible (0 upwards) between an asterisk and the end of the string'
|
||||
return re.sub('\*[^\*]*?(\*|$)','',string)
|
||||
return re.sub('\*[^\*]*?(\*|$)', '', string)
|
||||
|
||||
|
||||
def triggers_are_in(string):
|
||||
string = remove_surrounded_chars(string)
|
||||
# regex searches for send|main|message|me (at the end of the word) followed by
|
||||
# a whole word of image|pic|picture|photo|snap|snapshot|selfie|meme(s),
|
||||
# (?aims) are regex parser flags
|
||||
return bool(re.search('(?aims)(send|mail|message|me)\\b.+?\\b(image|pic(ture)?|photo|snap(shot)?|selfie|meme)s?\\b', string))
|
||||
|
||||
|
||||
# I don't even need input_hijack for this as visible text will be commited to history as the unmodified string
|
||||
def input_modifier(string):
|
||||
"""
|
||||
This function is applied to your text inputs before
|
||||
they are fed into the model.
|
||||
"""
|
||||
global params, picture_response
|
||||
if not params['enable_SD_api']:
|
||||
|
||||
global params
|
||||
|
||||
if not params['mode'] == 1: # if not in immersive/interactive mode, do nothing
|
||||
return string
|
||||
|
||||
commands = ['send', 'mail', 'me']
|
||||
mediums = ['image', 'pic', 'picture', 'photo']
|
||||
subjects = ['yourself', 'own']
|
||||
lowstr = string.lower()
|
||||
|
||||
# TODO: refactor out to separate handler and also replace detection with a regexp
|
||||
if any(command in lowstr for command in commands) and any(case in lowstr for case in mediums): # trigger the generation if a command signature and a medium signature is found
|
||||
picture_response = True
|
||||
shared.args.no_stream = True # Disable streaming cause otherwise the SD-generated picture would return as a dud
|
||||
shared.processing_message = "*Is sending a picture...*"
|
||||
string = "Please provide a detailed description of your surroundings, how you look and the situation you're in and what you are doing right now"
|
||||
if any(target in lowstr for target in subjects): # the focus of the image should be on the sending character
|
||||
string = "Please provide a detailed and vivid description of how you look and what you are wearing"
|
||||
if triggers_are_in(string): # if we're in it, check for trigger words
|
||||
toggle_generation(True)
|
||||
string = string.lower()
|
||||
if "of" in string:
|
||||
subject = string.split('of', 1)[1] # subdivide the string once by the first 'of' instance and get what's coming after it
|
||||
string = "Please provide a detailed and vivid description of " + subject
|
||||
else:
|
||||
string = "Please provide a detailed description of your appearance, your surroundings and what you are doing right now"
|
||||
|
||||
return string
|
||||
|
||||
# Get and save the Stable Diffusion-generated picture
|
||||
def get_SD_pictures(description):
|
||||
|
||||
global params, pic_id
|
||||
global params
|
||||
|
||||
if params['manage_VRAM']:
|
||||
give_VRAM_priority('SD')
|
||||
|
||||
payload = {
|
||||
"prompt": params['prompt_prefix'] + description,
|
||||
"seed": -1,
|
||||
"sampler_name": "DPM++ 2M Karras",
|
||||
"steps": 32,
|
||||
"cfg_scale": 7,
|
||||
"width": params['side_length'],
|
||||
"height": params['side_length'],
|
||||
"seed": params['seed'],
|
||||
"sampler_name": params['sampler_name'],
|
||||
"steps": params['steps'],
|
||||
"cfg_scale": params['cfg_scale'],
|
||||
"width": params['width'],
|
||||
"height": params['height'],
|
||||
"restore_faces": params['restore_faces'],
|
||||
"negative_prompt": params['negative_prompt']
|
||||
}
|
||||
|
||||
|
||||
print(f'Prompting the image generator via the API on {params["address"]}...')
|
||||
response = requests.post(url=f'{params["address"]}/sdapi/v1/txt2img', json=payload)
|
||||
response.raise_for_status()
|
||||
r = response.json()
|
||||
|
||||
visible_result = ""
|
||||
for img_str in r['images']:
|
||||
image = Image.open(io.BytesIO(base64.b64decode(img_str.split(",",1)[0])))
|
||||
image = Image.open(io.BytesIO(base64.b64decode(img_str.split(",", 1)[0])))
|
||||
if params['save_img']:
|
||||
output_file = Path(f'extensions/sd_api_pictures/outputs/{pic_id:06d}.png')
|
||||
variadic = f'{date.today().strftime("%Y_%m_%d")}/{shared.character}_{int(time.time())}'
|
||||
output_file = Path(f'extensions/sd_api_pictures/outputs/{variadic}.png')
|
||||
output_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
image.save(output_file.as_posix())
|
||||
pic_id += 1
|
||||
# lower the resolution of received images for the chat, otherwise the log size gets out of control quickly with all the base64 values in visible history
|
||||
image.thumbnail((300, 300))
|
||||
buffered = io.BytesIO()
|
||||
image.save(buffered, format="JPEG")
|
||||
buffered.seek(0)
|
||||
image_bytes = buffered.getvalue()
|
||||
img_str = "data:image/jpeg;base64," + base64.b64encode(image_bytes).decode()
|
||||
visible_result = visible_result + f'<img src="{img_str}" alt="{description}">\n'
|
||||
|
||||
visible_result = visible_result + f'<img src="/file/extensions/sd_api_pictures/outputs/{variadic}.png" alt="{description}" style="max-width: unset; max-height: unset;">\n'
|
||||
else:
|
||||
# lower the resolution of received images for the chat, otherwise the log size gets out of control quickly with all the base64 values in visible history
|
||||
image.thumbnail((300, 300))
|
||||
buffered = io.BytesIO()
|
||||
image.save(buffered, format="JPEG")
|
||||
buffered.seek(0)
|
||||
image_bytes = buffered.getvalue()
|
||||
img_str = "data:image/jpeg;base64," + base64.b64encode(image_bytes).decode()
|
||||
visible_result = visible_result + f'<img src="{img_str}" alt="{description}">\n'
|
||||
|
||||
if params['manage_VRAM']:
|
||||
give_VRAM_priority('LLM')
|
||||
|
||||
return visible_result
|
||||
|
||||
# TODO: how do I make the UI history ignore the resulting pictures (I don't want HTML to appear in history)
|
||||
@@ -105,7 +166,8 @@ def output_modifier(string):
|
||||
"""
|
||||
This function is applied to the model outputs.
|
||||
"""
|
||||
global pic_id, picture_response, streaming_state
|
||||
|
||||
global picture_response, params
|
||||
|
||||
if not picture_response:
|
||||
return string
|
||||
@@ -118,17 +180,19 @@ def output_modifier(string):
|
||||
|
||||
if string == '':
|
||||
string = 'no viable description in reply, try regenerating'
|
||||
return string
|
||||
|
||||
# I can't for the love of all that's holy get the name from shared.gradio['name1'], so for now it will be like this
|
||||
text = f'*Description: "{string}"*'
|
||||
text = ""
|
||||
if (params['mode'] < 2):
|
||||
toggle_generation(False)
|
||||
text = f'*Sends a picture which portrays: “{string}”*'
|
||||
else:
|
||||
text = string
|
||||
|
||||
image = get_SD_pictures(string)
|
||||
string = get_SD_pictures(string) + "\n" + text
|
||||
|
||||
picture_response = False
|
||||
return string
|
||||
|
||||
shared.processing_message = "*Is typing...*"
|
||||
shared.args.no_stream = streaming_state
|
||||
return image + "\n" + text
|
||||
|
||||
def bot_prefix_modifier(string):
|
||||
"""
|
||||
@@ -139,41 +203,92 @@ def bot_prefix_modifier(string):
|
||||
|
||||
return string
|
||||
|
||||
def force_pic():
|
||||
global picture_response
|
||||
picture_response = True
|
||||
|
||||
def toggle_generation(*args):
|
||||
global picture_response, shared, streaming_state
|
||||
|
||||
if not args:
|
||||
picture_response = not picture_response
|
||||
else:
|
||||
picture_response = args[0]
|
||||
|
||||
shared.args.no_stream = True if picture_response else streaming_state # Disable streaming cause otherwise the SD-generated picture would return as a dud
|
||||
shared.processing_message = "*Is sending a picture...*" if picture_response else "*Is typing...*"
|
||||
|
||||
|
||||
def filter_address(address):
|
||||
address = address.strip()
|
||||
# address = re.sub('http(s)?:\/\/|\/$','',address) # remove starting http:// OR https:// OR trailing slash
|
||||
address = re.sub('\/$', '', address) # remove trailing /s
|
||||
if not address.startswith('http'):
|
||||
address = 'http://' + address
|
||||
return address
|
||||
|
||||
|
||||
def SD_api_address_update(address):
|
||||
|
||||
global params
|
||||
|
||||
msg = "✔️ SD API is found on:"
|
||||
address = filter_address(address)
|
||||
params.update({"address": address})
|
||||
try:
|
||||
response = requests.get(url=f'{params["address"]}/sdapi/v1/sd-models')
|
||||
response.raise_for_status()
|
||||
# r = response.json()
|
||||
except:
|
||||
msg = "❌ No SD API endpoint on:"
|
||||
|
||||
return gr.Textbox.update(label=msg)
|
||||
|
||||
|
||||
def ui():
|
||||
|
||||
# Gradio elements
|
||||
with gr.Accordion("Stable Diffusion api integration", open=True):
|
||||
# gr.Markdown('### Stable Diffusion API Pictures') # Currently the name of extension is shown as the title
|
||||
with gr.Accordion("Parameters", open=True):
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
enable = gr.Checkbox(value=params['enable_SD_api'], label='Activate SD Api integration')
|
||||
save_img = gr.Checkbox(value=params['save_img'], label='Keep original received images in the outputs subdir')
|
||||
with gr.Column():
|
||||
address = gr.Textbox(placeholder=params['address'], value=params['address'], label='Stable Diffusion host address')
|
||||
|
||||
with gr.Row():
|
||||
force_btn = gr.Button("Force the next response to be a picture")
|
||||
generate_now_btn = gr.Button("Generate an image response to the input")
|
||||
address = gr.Textbox(placeholder=params['address'], value=params['address'], label='Auto1111\'s WebUI address')
|
||||
mode = gr.Dropdown(["Manual", "Immersive/Interactive", "Picturebook/Adventure"], value="Manual", label="Mode of operation", type="index")
|
||||
with gr.Column(scale=1, min_width=300):
|
||||
manage_VRAM = gr.Checkbox(value=params['manage_VRAM'], label='Manage VRAM')
|
||||
save_img = gr.Checkbox(value=params['save_img'], label='Keep original images and use them in chat')
|
||||
|
||||
force_pic = gr.Button("Force the picture response")
|
||||
suppr_pic = gr.Button("Suppress the picture response")
|
||||
|
||||
with gr.Accordion("Generation parameters", open=False):
|
||||
prompt_prefix = gr.Textbox(placeholder=params['prompt_prefix'], value=params['prompt_prefix'], label='Prompt Prefix (best used to describe the look of the character)')
|
||||
with gr.Row():
|
||||
negative_prompt = gr.Textbox(placeholder=params['negative_prompt'], value=params['negative_prompt'], label='Negative Prompt')
|
||||
dimensions = gr.Slider(256,702,value=params['side_length'],step=64,label='Image dimensions')
|
||||
# model = gr.Dropdown(value=SD_models[0], choices=SD_models, label='Model')
|
||||
|
||||
with gr.Column():
|
||||
negative_prompt = gr.Textbox(placeholder=params['negative_prompt'], value=params['negative_prompt'], label='Negative Prompt')
|
||||
sampler_name = gr.Textbox(placeholder=params['sampler_name'], value=params['sampler_name'], label='Sampler')
|
||||
with gr.Column():
|
||||
width = gr.Slider(256, 768, value=params['width'], step=64, label='Width')
|
||||
height = gr.Slider(256, 768, value=params['height'], step=64, label='Height')
|
||||
with gr.Row():
|
||||
steps = gr.Number(label="Steps:", value=params['steps'])
|
||||
seed = gr.Number(label="Seed:", value=params['seed'])
|
||||
cfg_scale = gr.Number(label="CFG Scale:", value=params['cfg_scale'])
|
||||
|
||||
# Event functions to update the parameters in the backend
|
||||
enable.change(lambda x: params.update({"enable_SD_api": x}), enable, None)
|
||||
address.change(lambda x: params.update({"address": filter_address(x)}), address, None)
|
||||
mode.select(lambda x: params.update({"mode": x}), mode, None)
|
||||
mode.select(lambda x: toggle_generation(x > 1), inputs=mode, outputs=None)
|
||||
manage_VRAM.change(lambda x: params.update({"manage_VRAM": x}), manage_VRAM, None)
|
||||
manage_VRAM.change(lambda x: give_VRAM_priority('set' if x else 'reset'), inputs=manage_VRAM, outputs=None)
|
||||
save_img.change(lambda x: params.update({"save_img": x}), save_img, None)
|
||||
address.change(lambda x: params.update({"address": x}), address, None)
|
||||
|
||||
address.submit(fn=SD_api_address_update, inputs=address, outputs=address)
|
||||
prompt_prefix.change(lambda x: params.update({"prompt_prefix": x}), prompt_prefix, None)
|
||||
negative_prompt.change(lambda x: params.update({"negative_prompt": x}), negative_prompt, None)
|
||||
dimensions.change(lambda x: params.update({"side_length": x}), dimensions, None)
|
||||
# model.change(lambda x: params.update({"SD_model": x}), model, None)
|
||||
width.change(lambda x: params.update({"width": x}), width, None)
|
||||
height.change(lambda x: params.update({"height": x}), height, None)
|
||||
|
||||
force_btn.click(force_pic)
|
||||
generate_now_btn.click(force_pic)
|
||||
generate_now_btn.click(chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream)
|
||||
sampler_name.change(lambda x: params.update({"sampler_name": x}), sampler_name, None)
|
||||
steps.change(lambda x: params.update({"steps": x}), steps, None)
|
||||
seed.change(lambda x: params.update({"seed": x}), seed, None)
|
||||
cfg_scale.change(lambda x: params.update({"cfg_scale": x}), cfg_scale, None)
|
||||
|
||||
force_pic.click(lambda x: toggle_generation(True), inputs=force_pic, outputs=None)
|
||||
suppr_pic.click(lambda x: toggle_generation(False), inputs=suppr_pic, outputs=None)
|
||||
|
||||
@@ -17,11 +17,13 @@ input_hijack = {
|
||||
processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
|
||||
model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base", torch_dtype=torch.float32).to("cpu")
|
||||
|
||||
|
||||
def caption_image(raw_image):
|
||||
inputs = processor(raw_image.convert('RGB'), return_tensors="pt").to("cpu", torch.float32)
|
||||
out = model.generate(**inputs, max_new_tokens=100)
|
||||
return processor.decode(out[0], skip_special_tokens=True)
|
||||
|
||||
|
||||
def generate_chat_picture(picture, name1, name2):
|
||||
text = f'*{name1} sends {name2} a picture that contains the following: "{caption_image(picture)}"*'
|
||||
# lower the resolution of sent images for the chat, otherwise the log size gets out of control quickly with all the base64 values in visible history
|
||||
@@ -32,6 +34,7 @@ def generate_chat_picture(picture, name1, name2):
|
||||
visible_text = f'<img src="data:image/jpeg;base64,{img_str}" alt="{text}">'
|
||||
return text, visible_text
|
||||
|
||||
|
||||
def ui():
|
||||
picture_select = gr.Image(label='Send a picture', type='pil')
|
||||
|
||||
@@ -42,4 +45,4 @@ def ui():
|
||||
picture_select.upload(chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream)
|
||||
|
||||
# Clear the picture from the upload field
|
||||
picture_select.upload(lambda : None, [], [picture_select], show_progress=False)
|
||||
picture_select.upload(lambda: None, [], [picture_select], show_progress=False)
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
ipython
|
||||
num2words
|
||||
omegaconf
|
||||
pydub
|
||||
PyYAML
|
||||
torch
|
||||
torchaudio
|
||||
|
||||
@@ -1,14 +1,16 @@
|
||||
import re
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import gradio as gr
|
||||
import modules.chat as chat
|
||||
import modules.shared as shared
|
||||
import torch
|
||||
|
||||
from extensions.silero_tts import tts_preprocessor
|
||||
from modules import chat, shared
|
||||
from modules.html_generator import chat_html_wrapper
|
||||
|
||||
torch._C._jit_set_profiling_mode(False)
|
||||
|
||||
|
||||
params = {
|
||||
'activate': True,
|
||||
'speaker': 'en_56',
|
||||
@@ -20,13 +22,14 @@ params = {
|
||||
'autoplay': True,
|
||||
'voice_pitch': 'medium',
|
||||
'voice_speed': 'medium',
|
||||
'local_cache_path': '' # User can override the default cache path to something other via settings.json
|
||||
}
|
||||
|
||||
current_params = params.copy()
|
||||
voices_by_gender = ['en_99', 'en_45', 'en_18', 'en_117', 'en_49', 'en_51', 'en_68', 'en_0', 'en_26', 'en_56', 'en_74', 'en_5', 'en_38', 'en_53', 'en_21', 'en_37', 'en_107', 'en_10', 'en_82', 'en_16', 'en_41', 'en_12', 'en_67', 'en_61', 'en_14', 'en_11', 'en_39', 'en_52', 'en_24', 'en_97', 'en_28', 'en_72', 'en_94', 'en_36', 'en_4', 'en_43', 'en_88', 'en_25', 'en_65', 'en_6', 'en_44', 'en_75', 'en_91', 'en_60', 'en_109', 'en_85', 'en_101', 'en_108', 'en_50', 'en_96', 'en_64', 'en_92', 'en_76', 'en_33', 'en_116', 'en_48', 'en_98', 'en_86', 'en_62', 'en_54', 'en_95', 'en_55', 'en_111', 'en_3', 'en_83', 'en_8', 'en_47', 'en_59', 'en_1', 'en_2', 'en_7', 'en_9', 'en_13', 'en_15', 'en_17', 'en_19', 'en_20', 'en_22', 'en_23', 'en_27', 'en_29', 'en_30', 'en_31', 'en_32', 'en_34', 'en_35', 'en_40', 'en_42', 'en_46', 'en_57', 'en_58', 'en_63', 'en_66', 'en_69', 'en_70', 'en_71', 'en_73', 'en_77', 'en_78', 'en_79', 'en_80', 'en_81', 'en_84', 'en_87', 'en_89', 'en_90', 'en_93', 'en_100', 'en_102', 'en_103', 'en_104', 'en_105', 'en_106', 'en_110', 'en_112', 'en_113', 'en_114', 'en_115']
|
||||
voice_pitches = ['x-low', 'low', 'medium', 'high', 'x-high']
|
||||
voice_speeds = ['x-slow', 'slow', 'medium', 'fast', 'x-fast']
|
||||
streaming_state = shared.args.no_stream # remember if chat streaming was enabled
|
||||
streaming_state = shared.args.no_stream # remember if chat streaming was enabled
|
||||
|
||||
# Used for making text xml compatible, needed for voice pitch and speed control
|
||||
table = str.maketrans({
|
||||
@@ -37,26 +40,31 @@ table = str.maketrans({
|
||||
'"': """,
|
||||
})
|
||||
|
||||
|
||||
def xmlesc(txt):
|
||||
return txt.translate(table)
|
||||
|
||||
|
||||
def load_model():
|
||||
model, example_text = torch.hub.load(repo_or_dir='snakers4/silero-models', model='silero_tts', language=params['language'], speaker=params['model_id'])
|
||||
torch_cache_path = torch.hub.get_dir() if params['local_cache_path'] == '' else params['local_cache_path']
|
||||
model_path = torch_cache_path + "/snakers4_silero-models_master/src/silero/model/" + params['model_id'] + ".pt"
|
||||
if Path(model_path).is_file():
|
||||
print(f'\nUsing Silero TTS cached checkpoint found at {torch_cache_path}')
|
||||
model, example_text = torch.hub.load(repo_or_dir=torch_cache_path + '/snakers4_silero-models_master/', model='silero_tts', language=params['language'], speaker=params['model_id'], source='local', path=model_path, force_reload=True)
|
||||
else:
|
||||
print(f'\nSilero TTS cache not found at {torch_cache_path}. Attempting to download...')
|
||||
model, example_text = torch.hub.load(repo_or_dir='snakers4/silero-models', model='silero_tts', language=params['language'], speaker=params['model_id'])
|
||||
model.to(params['device'])
|
||||
return model
|
||||
model = load_model()
|
||||
|
||||
def remove_surrounded_chars(string):
|
||||
# this expression matches to 'as few symbols as possible (0 upwards) between any asterisks' OR
|
||||
# 'as few symbols as possible (0 upwards) between an asterisk and the end of the string'
|
||||
return re.sub('\*[^\*]*?(\*|$)','',string)
|
||||
|
||||
def remove_tts_from_history(name1, name2):
|
||||
def remove_tts_from_history(name1, name2, mode):
|
||||
for i, entry in enumerate(shared.history['internal']):
|
||||
shared.history['visible'][i] = [shared.history['visible'][i][0], entry[1]]
|
||||
return chat.generate_chat_output(shared.history['visible'], name1, name2, shared.character)
|
||||
return chat_html_wrapper(shared.history['visible'], name1, name2, mode)
|
||||
|
||||
def toggle_text_in_history(name1, name2):
|
||||
|
||||
def toggle_text_in_history(name1, name2, mode):
|
||||
for i, entry in enumerate(shared.history['visible']):
|
||||
visible_reply = entry[1]
|
||||
if visible_reply.startswith('<audio'):
|
||||
@@ -65,7 +73,8 @@ def toggle_text_in_history(name1, name2):
|
||||
shared.history['visible'][i] = [shared.history['visible'][i][0], f"{visible_reply.split('</audio>')[0]}</audio>\n\n{reply}"]
|
||||
else:
|
||||
shared.history['visible'][i] = [shared.history['visible'][i][0], f"{visible_reply.split('</audio>')[0]}</audio>"]
|
||||
return chat.generate_chat_output(shared.history['visible'], name1, name2, shared.character)
|
||||
return chat_html_wrapper(shared.history['visible'], name1, name2, mode)
|
||||
|
||||
|
||||
def input_modifier(string):
|
||||
"""
|
||||
@@ -75,12 +84,13 @@ def input_modifier(string):
|
||||
|
||||
# Remove autoplay from the last reply
|
||||
if shared.is_chat() and len(shared.history['internal']) > 0:
|
||||
shared.history['visible'][-1] = [shared.history['visible'][-1][0], shared.history['visible'][-1][1].replace('controls autoplay>','controls>')]
|
||||
shared.history['visible'][-1] = [shared.history['visible'][-1][0], shared.history['visible'][-1][1].replace('controls autoplay>', 'controls>')]
|
||||
|
||||
shared.processing_message = "*Is recording a voice message...*"
|
||||
shared.args.no_stream = True # Disable streaming cause otherwise the audio output will stutter and begin anew every time the message is being updated
|
||||
shared.args.no_stream = True # Disable streaming cause otherwise the audio output will stutter and begin anew every time the message is being updated
|
||||
return string
|
||||
|
||||
|
||||
def output_modifier(string):
|
||||
"""
|
||||
This function is applied to the model outputs.
|
||||
@@ -94,15 +104,11 @@ def output_modifier(string):
|
||||
current_params = params.copy()
|
||||
break
|
||||
|
||||
if params['activate'] == False:
|
||||
if not params['activate']:
|
||||
return string
|
||||
|
||||
original_string = string
|
||||
string = remove_surrounded_chars(string)
|
||||
string = string.replace('"', '')
|
||||
string = string.replace('“', '')
|
||||
string = string.replace('\n', ' ')
|
||||
string = string.strip()
|
||||
string = tts_preprocessor.preprocess(string)
|
||||
|
||||
if string == '':
|
||||
string = '*Empty reply, try regenerating*'
|
||||
@@ -118,9 +124,10 @@ def output_modifier(string):
|
||||
string += f'\n\n{original_string}'
|
||||
|
||||
shared.processing_message = "*Is typing...*"
|
||||
shared.args.no_stream = streaming_state # restore the streaming option to the previous value
|
||||
shared.args.no_stream = streaming_state # restore the streaming option to the previous value
|
||||
return string
|
||||
|
||||
|
||||
def bot_prefix_modifier(string):
|
||||
"""
|
||||
This function is only applied in chat mode. It modifies
|
||||
@@ -130,17 +137,25 @@ def bot_prefix_modifier(string):
|
||||
|
||||
return string
|
||||
|
||||
|
||||
def setup():
|
||||
global model
|
||||
model = load_model()
|
||||
|
||||
|
||||
def ui():
|
||||
# Gradio elements
|
||||
with gr.Accordion("Silero TTS"):
|
||||
with gr.Row():
|
||||
activate = gr.Checkbox(value=params['activate'], label='Activate TTS')
|
||||
autoplay = gr.Checkbox(value=params['autoplay'], label='Play TTS automatically')
|
||||
|
||||
show_text = gr.Checkbox(value=params['show_text'], label='Show message text under audio player')
|
||||
voice = gr.Dropdown(value=params['speaker'], choices=voices_by_gender, label='TTS voice')
|
||||
with gr.Row():
|
||||
v_pitch = gr.Dropdown(value=params['voice_pitch'], choices=voice_pitches, label='Voice pitch')
|
||||
v_speed = gr.Dropdown(value=params['voice_speed'], choices=voice_speeds, label='Voice speed')
|
||||
|
||||
with gr.Row():
|
||||
convert = gr.Button('Permanently replace audios with the message texts')
|
||||
convert_cancel = gr.Button('Cancel', visible=False)
|
||||
@@ -148,20 +163,20 @@ def ui():
|
||||
|
||||
# Convert history with confirmation
|
||||
convert_arr = [convert_confirm, convert, convert_cancel]
|
||||
convert.click(lambda :[gr.update(visible=True), gr.update(visible=False), gr.update(visible=True)], None, convert_arr)
|
||||
convert_confirm.click(lambda :[gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, convert_arr)
|
||||
convert_confirm.click(remove_tts_from_history, [shared.gradio['name1'], shared.gradio['name2']], shared.gradio['display'])
|
||||
convert_confirm.click(lambda : chat.save_history(timestamp=False), [], [], show_progress=False)
|
||||
convert_cancel.click(lambda :[gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, convert_arr)
|
||||
convert.click(lambda: [gr.update(visible=True), gr.update(visible=False), gr.update(visible=True)], None, convert_arr)
|
||||
convert_confirm.click(lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, convert_arr)
|
||||
convert_confirm.click(remove_tts_from_history, [shared.gradio[k] for k in ['name1', 'name2', 'Chat mode']], shared.gradio['display'])
|
||||
convert_confirm.click(lambda: chat.save_history(timestamp=False), [], [], show_progress=False)
|
||||
convert_cancel.click(lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, convert_arr)
|
||||
|
||||
# Toggle message text in history
|
||||
show_text.change(lambda x: params.update({"show_text": x}), show_text, None)
|
||||
show_text.change(toggle_text_in_history, [shared.gradio['name1'], shared.gradio['name2']], shared.gradio['display'])
|
||||
show_text.change(lambda : chat.save_history(timestamp=False), [], [], show_progress=False)
|
||||
show_text.change(toggle_text_in_history, [shared.gradio[k] for k in ['name1', 'name2', 'Chat mode']], shared.gradio['display'])
|
||||
show_text.change(lambda: chat.save_history(timestamp=False), [], [], show_progress=False)
|
||||
|
||||
# Event functions to update the parameters in the backend
|
||||
activate.change(lambda x: params.update({"activate": x}), activate, None)
|
||||
autoplay.change(lambda x: params.update({"autoplay": x}), autoplay, None)
|
||||
voice.change(lambda x: params.update({"speaker": x}), voice, None)
|
||||
v_pitch.change(lambda x: params.update({"voice_pitch": x}), v_pitch, None)
|
||||
v_speed.change(lambda x: params.update({"voice_speed": x}), v_speed, None)
|
||||
v_speed.change(lambda x: params.update({"voice_speed": x}), v_speed, None)
|
||||
|
||||
81
extensions/silero_tts/test_tts.py
Normal file
81
extensions/silero_tts/test_tts.py
Normal file
@@ -0,0 +1,81 @@
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import tts_preprocessor
|
||||
|
||||
torch._C._jit_set_profiling_mode(False)
|
||||
|
||||
|
||||
params = {
|
||||
'activate': True,
|
||||
'speaker': 'en_49',
|
||||
'language': 'en',
|
||||
'model_id': 'v3_en',
|
||||
'sample_rate': 48000,
|
||||
'device': 'cpu',
|
||||
'show_text': True,
|
||||
'autoplay': True,
|
||||
'voice_pitch': 'medium',
|
||||
'voice_speed': 'medium',
|
||||
}
|
||||
|
||||
current_params = params.copy()
|
||||
voices_by_gender = ['en_99', 'en_45', 'en_18', 'en_117', 'en_49', 'en_51', 'en_68', 'en_0', 'en_26', 'en_56', 'en_74', 'en_5', 'en_38', 'en_53', 'en_21', 'en_37', 'en_107', 'en_10', 'en_82', 'en_16', 'en_41', 'en_12', 'en_67', 'en_61', 'en_14', 'en_11', 'en_39', 'en_52', 'en_24', 'en_97', 'en_28', 'en_72', 'en_94', 'en_36', 'en_4', 'en_43', 'en_88', 'en_25', 'en_65', 'en_6', 'en_44', 'en_75', 'en_91', 'en_60', 'en_109', 'en_85', 'en_101', 'en_108', 'en_50', 'en_96', 'en_64', 'en_92', 'en_76', 'en_33', 'en_116', 'en_48', 'en_98', 'en_86', 'en_62', 'en_54', 'en_95', 'en_55', 'en_111', 'en_3', 'en_83', 'en_8', 'en_47', 'en_59', 'en_1', 'en_2', 'en_7', 'en_9', 'en_13', 'en_15', 'en_17', 'en_19', 'en_20', 'en_22', 'en_23', 'en_27', 'en_29', 'en_30', 'en_31', 'en_32', 'en_34', 'en_35', 'en_40', 'en_42', 'en_46', 'en_57', 'en_58', 'en_63', 'en_66', 'en_69', 'en_70', 'en_71', 'en_73', 'en_77', 'en_78', 'en_79', 'en_80', 'en_81', 'en_84', 'en_87', 'en_89', 'en_90', 'en_93', 'en_100', 'en_102', 'en_103', 'en_104', 'en_105', 'en_106', 'en_110', 'en_112', 'en_113', 'en_114', 'en_115']
|
||||
voice_pitches = ['x-low', 'low', 'medium', 'high', 'x-high']
|
||||
voice_speeds = ['x-slow', 'slow', 'medium', 'fast', 'x-fast']
|
||||
|
||||
# Used for making text xml compatible, needed for voice pitch and speed control
|
||||
table = str.maketrans({
|
||||
"<": "<",
|
||||
">": ">",
|
||||
"&": "&",
|
||||
"'": "'",
|
||||
'"': """,
|
||||
})
|
||||
|
||||
|
||||
def xmlesc(txt):
|
||||
return txt.translate(table)
|
||||
|
||||
|
||||
def load_model():
|
||||
model, example_text = torch.hub.load(repo_or_dir='snakers4/silero-models', model='silero_tts', language=params['language'], speaker=params['model_id'])
|
||||
model.to(params['device'])
|
||||
return model
|
||||
|
||||
|
||||
model = load_model()
|
||||
|
||||
|
||||
def output_modifier(string):
|
||||
"""
|
||||
This function is applied to the model outputs.
|
||||
"""
|
||||
|
||||
global model, current_params
|
||||
|
||||
original_string = string
|
||||
string = tts_preprocessor.preprocess(string)
|
||||
processed_string = string
|
||||
|
||||
if string == '':
|
||||
string = '*Empty reply, try regenerating*'
|
||||
else:
|
||||
output_file = Path(f'extensions/silero_tts/outputs/test_{int(time.time())}.wav')
|
||||
prosody = '<prosody rate="{}" pitch="{}">'.format(params['voice_speed'], params['voice_pitch'])
|
||||
silero_input = f'<speak>{prosody}{xmlesc(string)}</prosody></speak>'
|
||||
model.save_wav(ssml_text=silero_input, speaker=params['speaker'], sample_rate=int(params['sample_rate']), audio_path=str(output_file))
|
||||
|
||||
autoplay = 'autoplay' if params['autoplay'] else ''
|
||||
string = f'<audio src="file/{output_file.as_posix()}" controls {autoplay}></audio>'
|
||||
|
||||
if params['show_text']:
|
||||
string += f'\n\n{original_string}\n\nProcessed:\n{processed_string}'
|
||||
|
||||
print(string)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
import sys
|
||||
output_modifier(sys.argv[1])
|
||||
194
extensions/silero_tts/tts_preprocessor.py
Normal file
194
extensions/silero_tts/tts_preprocessor.py
Normal file
@@ -0,0 +1,194 @@
|
||||
import re
|
||||
|
||||
from num2words import num2words
|
||||
|
||||
punctuation = r'[\s,.?!/)\'\]>]'
|
||||
alphabet_map = {
|
||||
"A": " Ei ",
|
||||
"B": " Bee ",
|
||||
"C": " See ",
|
||||
"D": " Dee ",
|
||||
"E": " Eee ",
|
||||
"F": " Eff ",
|
||||
"G": " Jee ",
|
||||
"H": " Eich ",
|
||||
"I": " Eye ",
|
||||
"J": " Jay ",
|
||||
"K": " Kay ",
|
||||
"L": " El ",
|
||||
"M": " Emm ",
|
||||
"N": " Enn ",
|
||||
"O": " Ohh ",
|
||||
"P": " Pee ",
|
||||
"Q": " Queue ",
|
||||
"R": " Are ",
|
||||
"S": " Ess ",
|
||||
"T": " Tee ",
|
||||
"U": " You ",
|
||||
"V": " Vee ",
|
||||
"W": " Double You ",
|
||||
"X": " Ex ",
|
||||
"Y": " Why ",
|
||||
"Z": " Zed " # Zed is weird, as I (da3dsoul) am American, but most of the voice models sound British, so it matches
|
||||
}
|
||||
|
||||
|
||||
def preprocess(string):
|
||||
# the order for some of these matter
|
||||
# For example, you need to remove the commas in numbers before expanding them
|
||||
string = remove_surrounded_chars(string)
|
||||
string = string.replace('"', '')
|
||||
string = string.replace('\u201D', '').replace('\u201C', '') # right and left quote
|
||||
string = string.replace('\u201F', '') # italic looking quote
|
||||
string = string.replace('\n', ' ')
|
||||
string = convert_num_locale(string)
|
||||
string = replace_negative(string)
|
||||
string = replace_roman(string)
|
||||
string = hyphen_range_to(string)
|
||||
string = num_to_words(string)
|
||||
|
||||
# TODO Try to use a ML predictor to expand abbreviations. It's hard, dependent on context, and whether to actually
|
||||
# try to say the abbreviation or spell it out as I've done below is not agreed upon
|
||||
|
||||
# For now, expand abbreviations to pronunciations
|
||||
# replace_abbreviations adds a lot of unnecessary whitespace to ensure separation
|
||||
string = replace_abbreviations(string)
|
||||
string = replace_lowercase_abbreviations(string)
|
||||
|
||||
# cleanup whitespaces
|
||||
# remove whitespace before punctuation
|
||||
string = re.sub(rf'\s+({punctuation})', r'\1', string)
|
||||
string = string.strip()
|
||||
# compact whitespace
|
||||
string = ' '.join(string.split())
|
||||
|
||||
return string
|
||||
|
||||
|
||||
def remove_surrounded_chars(string):
|
||||
# this expression matches to 'as few symbols as possible (0 upwards) between any asterisks' OR
|
||||
# 'as few symbols as possible (0 upwards) between an asterisk and the end of the string'
|
||||
return re.sub(r'\*[^*]*?(\*|$)', '', string)
|
||||
|
||||
|
||||
def convert_num_locale(text):
|
||||
# This detects locale and converts it to American without comma separators
|
||||
pattern = re.compile(r'(?:\s|^)\d{1,3}(?:\.\d{3})+(,\d+)(?:\s|$)')
|
||||
result = text
|
||||
while True:
|
||||
match = pattern.search(result)
|
||||
if match is None:
|
||||
break
|
||||
|
||||
start = match.start()
|
||||
end = match.end()
|
||||
result = result[0:start] + result[start:end].replace('.', '').replace(',', '.') + result[end:len(result)]
|
||||
|
||||
# removes comma separators from existing American numbers
|
||||
pattern = re.compile(r'(\d),(\d)')
|
||||
result = pattern.sub(r'\1\2', result)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def replace_negative(string):
|
||||
# handles situations like -5. -5 would become negative 5, which would then be expanded to negative five
|
||||
return re.sub(rf'(\s)(-)(\d+)({punctuation})', r'\1negative \3\4', string)
|
||||
|
||||
|
||||
def replace_roman(string):
|
||||
# find a string of roman numerals.
|
||||
# Only 2 or more, to avoid capturing I and single character abbreviations, like names
|
||||
pattern = re.compile(rf'\s[IVXLCDM]{{2,}}{punctuation}')
|
||||
result = string
|
||||
while True:
|
||||
match = pattern.search(result)
|
||||
if match is None:
|
||||
break
|
||||
|
||||
start = match.start()
|
||||
end = match.end()
|
||||
result = result[0:start + 1] + str(roman_to_int(result[start + 1:end - 1])) + result[end - 1:len(result)]
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def roman_to_int(s):
|
||||
rom_val = {'I': 1, 'V': 5, 'X': 10, 'L': 50, 'C': 100, 'D': 500, 'M': 1000}
|
||||
int_val = 0
|
||||
for i in range(len(s)):
|
||||
if i > 0 and rom_val[s[i]] > rom_val[s[i - 1]]:
|
||||
int_val += rom_val[s[i]] - 2 * rom_val[s[i - 1]]
|
||||
else:
|
||||
int_val += rom_val[s[i]]
|
||||
return int_val
|
||||
|
||||
|
||||
def hyphen_range_to(text):
|
||||
pattern = re.compile(r'(\d+)[-–](\d+)')
|
||||
result = pattern.sub(lambda x: x.group(1) + ' to ' + x.group(2), text)
|
||||
return result
|
||||
|
||||
|
||||
def num_to_words(text):
|
||||
# 1000 or 10.23
|
||||
pattern = re.compile(r'\d+\.\d+|\d+')
|
||||
result = pattern.sub(lambda x: num2words(float(x.group())), text)
|
||||
return result
|
||||
|
||||
|
||||
def replace_abbreviations(string):
|
||||
# abbreviations 1 to 4 characters long. It will get things like A and I, but those are pronounced with their letter
|
||||
pattern = re.compile(rf'(^|[\s(.\'\[<])([A-Z]{{1,4}})({punctuation}|$)')
|
||||
result = string
|
||||
while True:
|
||||
match = pattern.search(result)
|
||||
if match is None:
|
||||
break
|
||||
|
||||
start = match.start()
|
||||
end = match.end()
|
||||
result = result[0:start] + replace_abbreviation(result[start:end]) + result[end:len(result)]
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def replace_lowercase_abbreviations(string):
|
||||
# abbreviations 1 to 4 characters long, separated by dots i.e. e.g.
|
||||
pattern = re.compile(rf'(^|[\s(.\'\[<])(([a-z]\.){{1,4}})({punctuation}|$)')
|
||||
result = string
|
||||
while True:
|
||||
match = pattern.search(result)
|
||||
if match is None:
|
||||
break
|
||||
|
||||
start = match.start()
|
||||
end = match.end()
|
||||
result = result[0:start] + replace_abbreviation(result[start:end].upper()) + result[end:len(result)]
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def replace_abbreviation(string):
|
||||
result = ""
|
||||
for char in string:
|
||||
result += match_mapping(char)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def match_mapping(char):
|
||||
for mapping in alphabet_map.keys():
|
||||
if char == mapping:
|
||||
return alphabet_map[char]
|
||||
|
||||
return char
|
||||
|
||||
|
||||
def __main__(args):
|
||||
print(preprocess(args[1]))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import sys
|
||||
__main__(sys.argv)
|
||||
@@ -1,5 +1,6 @@
|
||||
import gradio as gr
|
||||
import speech_recognition as sr
|
||||
from modules import shared
|
||||
|
||||
input_hijack = {
|
||||
'state': False,
|
||||
@@ -7,7 +8,7 @@ input_hijack = {
|
||||
}
|
||||
|
||||
|
||||
def do_stt(audio, text_state=""):
|
||||
def do_stt(audio):
|
||||
transcription = ""
|
||||
r = sr.Recognizer()
|
||||
|
||||
@@ -21,34 +22,23 @@ def do_stt(audio, text_state=""):
|
||||
except sr.RequestError as e:
|
||||
print("Could not request results from Whisper", e)
|
||||
|
||||
input_hijack.update({"state": True, "value": [transcription, transcription]})
|
||||
|
||||
text_state += transcription + " "
|
||||
return text_state, text_state
|
||||
return transcription
|
||||
|
||||
|
||||
def update_hijack(val):
|
||||
input_hijack.update({"state": True, "value": [val, val]})
|
||||
return val
|
||||
|
||||
|
||||
def auto_transcribe(audio, audio_auto, text_state=""):
|
||||
def auto_transcribe(audio, auto_submit):
|
||||
if audio is None:
|
||||
return "", ""
|
||||
if audio_auto:
|
||||
return do_stt(audio, text_state)
|
||||
return "", ""
|
||||
|
||||
transcription = do_stt(audio)
|
||||
if auto_submit:
|
||||
input_hijack.update({"state": True, "value": [transcription, transcription]})
|
||||
|
||||
return transcription, None
|
||||
|
||||
|
||||
def ui():
|
||||
tr_state = gr.State(value="")
|
||||
output_transcription = gr.Textbox(label="STT-Input",
|
||||
placeholder="Speech Preview. Click \"Generate\" to send",
|
||||
interactive=True)
|
||||
output_transcription.change(fn=update_hijack, inputs=[output_transcription], outputs=[tr_state])
|
||||
audio_auto = gr.Checkbox(label="Auto-Transcribe", value=True)
|
||||
with gr.Row():
|
||||
audio = gr.Audio(source="microphone")
|
||||
audio.change(fn=auto_transcribe, inputs=[audio, audio_auto, tr_state], outputs=[output_transcription, tr_state])
|
||||
transcribe_button = gr.Button(value="Transcribe")
|
||||
transcribe_button.click(do_stt, inputs=[audio, tr_state], outputs=[output_transcription, tr_state])
|
||||
auto_submit = gr.Checkbox(label='Submit the transcribed audio automatically', value=True)
|
||||
audio.change(fn=auto_transcribe, inputs=[audio, auto_submit], outputs=[shared.gradio['textbox'], audio])
|
||||
audio.change(None, auto_submit, None, _js="(check) => {if (check) { document.getElementById('Generate').click() }}")
|
||||
|
||||
@@ -17,9 +17,11 @@ from quant import make_quant
|
||||
|
||||
|
||||
def _load_quant(model, checkpoint, wbits, groupsize=-1, faster_kernel=False, exclude_layers=['lm_head'], kernel_switch_threshold=128):
|
||||
config = AutoConfig.from_pretrained(model)
|
||||
|
||||
def noop(*args, **kwargs):
|
||||
pass
|
||||
|
||||
config = AutoConfig.from_pretrained(model)
|
||||
torch.nn.init.kaiming_uniform_ = noop
|
||||
torch.nn.init.uniform_ = noop
|
||||
torch.nn.init.normal_ = noop
|
||||
@@ -34,11 +36,11 @@ def _load_quant(model, checkpoint, wbits, groupsize=-1, faster_kernel=False, exc
|
||||
for name in exclude_layers:
|
||||
if name in layers:
|
||||
del layers[name]
|
||||
|
||||
|
||||
gptq_args = inspect.getfullargspec(make_quant).args
|
||||
|
||||
make_quant_kwargs = {
|
||||
'module': model,
|
||||
'module': model,
|
||||
'names': layers,
|
||||
'bits': wbits,
|
||||
}
|
||||
@@ -48,7 +50,7 @@ def _load_quant(model, checkpoint, wbits, groupsize=-1, faster_kernel=False, exc
|
||||
make_quant_kwargs['faster'] = faster_kernel
|
||||
if 'kernel_switch_threshold' in gptq_args:
|
||||
make_quant_kwargs['kernel_switch_threshold'] = kernel_switch_threshold
|
||||
|
||||
|
||||
make_quant(**make_quant_kwargs)
|
||||
|
||||
del layers
|
||||
@@ -56,14 +58,15 @@ def _load_quant(model, checkpoint, wbits, groupsize=-1, faster_kernel=False, exc
|
||||
print('Loading model ...')
|
||||
if checkpoint.endswith('.safetensors'):
|
||||
from safetensors.torch import load_file as safe_load
|
||||
model.load_state_dict(safe_load(checkpoint), strict = False)
|
||||
model.load_state_dict(safe_load(checkpoint), strict=False)
|
||||
else:
|
||||
model.load_state_dict(torch.load(checkpoint), strict = False)
|
||||
model.load_state_dict(torch.load(checkpoint), strict=False)
|
||||
model.seqlen = 2048
|
||||
print('Done.')
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def load_quantized(model_name):
|
||||
if not shared.args.model_type:
|
||||
# Try to determine model type from model name
|
||||
@@ -114,7 +117,7 @@ def load_quantized(model_name):
|
||||
pt_model = f'{model_name}-{shared.args.wbits}bit'
|
||||
|
||||
# Try to find the .safetensors or .pt both in the model dir and in the subfolder
|
||||
for path in [Path(p+ext) for ext in ['.safetensors', '.pt'] for p in [f"{shared.args.model_dir}/{pt_model}", f"{path_to_model}/{pt_model}"]]:
|
||||
for path in [Path(p + ext) for ext in ['.safetensors', '.pt'] for p in [f"{shared.args.model_dir}/{pt_model}", f"{path_to_model}/{pt_model}"]]:
|
||||
if path.exists():
|
||||
print(f"Found {path}")
|
||||
pt_path = path
|
||||
@@ -133,7 +136,7 @@ def load_quantized(model_name):
|
||||
|
||||
# accelerate offload (doesn't work properly)
|
||||
if shared.args.gpu_memory:
|
||||
memory_map = list(map(lambda x : x.strip(), shared.args.gpu_memory))
|
||||
memory_map = list(map(lambda x: x.strip(), shared.args.gpu_memory))
|
||||
max_cpu_memory = shared.args.cpu_memory.strip() if shared.args.cpu_memory is not None else '99GiB'
|
||||
max_memory = {}
|
||||
for i in range(len(memory_map)):
|
||||
|
||||
@@ -4,15 +4,9 @@ import torch
|
||||
from peft import PeftModel
|
||||
|
||||
import modules.shared as shared
|
||||
from modules.models import load_model
|
||||
from modules.text_generation import clear_torch_cache
|
||||
from modules.models import reload_model
|
||||
|
||||
|
||||
def reload_model():
|
||||
shared.model = shared.tokenizer = None
|
||||
clear_torch_cache()
|
||||
shared.model, shared.tokenizer = load_model(shared.model_name)
|
||||
|
||||
def add_lora_to_model(lora_name):
|
||||
|
||||
# If a LoRA had been previously loaded, or if we want
|
||||
@@ -27,10 +21,10 @@ def add_lora_to_model(lora_name):
|
||||
if not shared.args.cpu:
|
||||
params['dtype'] = shared.model.dtype
|
||||
if hasattr(shared.model, "hf_device_map"):
|
||||
params['device_map'] = {"base_model.model."+k: v for k, v in shared.model.hf_device_map.items()}
|
||||
params['device_map'] = {"base_model.model." + k: v for k, v in shared.model.hf_device_map.items()}
|
||||
elif shared.args.load_in_8bit:
|
||||
params['device_map'] = {'': 0}
|
||||
|
||||
|
||||
shared.model = PeftModel.from_pretrained(shared.model, Path(f"{shared.args.lora_dir}/{lora_name}"), **params)
|
||||
if not shared.args.load_in_8bit and not shared.args.cpu:
|
||||
shared.model.half()
|
||||
|
||||
@@ -10,7 +10,7 @@ from modules.callbacks import Iteratorize
|
||||
np.set_printoptions(precision=4, suppress=True, linewidth=200)
|
||||
|
||||
os.environ['RWKV_JIT_ON'] = '1'
|
||||
os.environ["RWKV_CUDA_ON"] = '1' if shared.args.rwkv_cuda_on else '0' # use CUDA kernel for seq mode (much faster)
|
||||
os.environ["RWKV_CUDA_ON"] = '1' if shared.args.rwkv_cuda_on else '0' # use CUDA kernel for seq mode (much faster)
|
||||
|
||||
from rwkv.model import RWKV
|
||||
from rwkv.utils import PIPELINE, PIPELINE_ARGS
|
||||
@@ -36,13 +36,13 @@ class RWKVModel:
|
||||
|
||||
def generate(self, context="", token_count=20, temperature=1, top_p=1, top_k=50, repetition_penalty=None, alpha_frequency=0.1, alpha_presence=0.1, token_ban=[0], token_stop=[], callback=None):
|
||||
args = PIPELINE_ARGS(
|
||||
temperature = temperature,
|
||||
top_p = top_p,
|
||||
top_k = top_k,
|
||||
alpha_frequency = alpha_frequency, # Frequency Penalty (as in GPT-3)
|
||||
alpha_presence = alpha_presence, # Presence Penalty (as in GPT-3)
|
||||
token_ban = token_ban, # ban the generation of some tokens
|
||||
token_stop = token_stop
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
top_k=top_k,
|
||||
alpha_frequency=alpha_frequency, # Frequency Penalty (as in GPT-3)
|
||||
alpha_presence=alpha_presence, # Presence Penalty (as in GPT-3)
|
||||
token_ban=token_ban, # ban the generation of some tokens
|
||||
token_stop=token_stop
|
||||
)
|
||||
|
||||
return self.pipeline.generate(context, token_count=token_count, args=args, callback=callback)
|
||||
@@ -54,6 +54,7 @@ class RWKVModel:
|
||||
reply += token
|
||||
yield reply
|
||||
|
||||
|
||||
class RWKVTokenizer:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@@ -28,6 +28,7 @@ def generate_reply_wrapper(string):
|
||||
for i in generate_reply(params[0], generate_params):
|
||||
yield i
|
||||
|
||||
|
||||
def create_apis():
|
||||
t1 = gr.Textbox(visible=False)
|
||||
t2 = gr.Textbox(visible=False)
|
||||
|
||||
@@ -30,6 +30,7 @@ class _SentinelTokenStoppingCriteria(transformers.StoppingCriteria):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
class Stream(transformers.StoppingCriteria):
|
||||
def __init__(self, callback_func=None):
|
||||
self.callback_func = callback_func
|
||||
@@ -39,6 +40,7 @@ class Stream(transformers.StoppingCriteria):
|
||||
self.callback_func(input_ids[0])
|
||||
return False
|
||||
|
||||
|
||||
class Iteratorize:
|
||||
|
||||
"""
|
||||
@@ -47,8 +49,8 @@ class Iteratorize:
|
||||
"""
|
||||
|
||||
def __init__(self, func, kwargs={}, callback=None):
|
||||
self.mfunc=func
|
||||
self.c_callback=callback
|
||||
self.mfunc = func
|
||||
self.c_callback = callback
|
||||
self.q = Queue()
|
||||
self.sentinel = object()
|
||||
self.kwargs = kwargs
|
||||
@@ -80,7 +82,7 @@ class Iteratorize:
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
obj = self.q.get(True,None)
|
||||
obj = self.q.get(True, None)
|
||||
if obj is self.sentinel:
|
||||
raise StopIteration
|
||||
else:
|
||||
@@ -96,6 +98,7 @@ class Iteratorize:
|
||||
self.stop_now = True
|
||||
clear_torch_cache()
|
||||
|
||||
|
||||
def clear_torch_cache():
|
||||
gc.collect()
|
||||
if not shared.args.cpu:
|
||||
|
||||
@@ -12,7 +12,7 @@ from PIL import Image
|
||||
import modules.extensions as extensions_module
|
||||
import modules.shared as shared
|
||||
from modules.extensions import apply_extensions
|
||||
from modules.html_generator import (fix_newlines, chat_html_wrapper,
|
||||
from modules.html_generator import (chat_html_wrapper, fix_newlines,
|
||||
make_thumbnail)
|
||||
from modules.text_generation import (encode, generate_reply,
|
||||
get_max_prompt_length)
|
||||
@@ -23,12 +23,11 @@ def generate_chat_prompt(user_input, max_new_tokens, name1, name2, context, chat
|
||||
end_of_turn = kwargs['end_of_turn'] if 'end_of_turn' in kwargs else ''
|
||||
impersonate = kwargs['impersonate'] if 'impersonate' in kwargs else False
|
||||
also_return_rows = kwargs['also_return_rows'] if 'also_return_rows' in kwargs else False
|
||||
|
||||
rows = [f"{context.strip()}\n"]
|
||||
|
||||
# Finding the maximum prompt size
|
||||
if shared.soft_prompt:
|
||||
chat_prompt_size -= shared.soft_prompt_tensor.shape[1]
|
||||
chat_prompt_size -= shared.soft_prompt_tensor.shape[1]
|
||||
max_length = min(get_max_prompt_length(max_new_tokens), chat_prompt_size)
|
||||
|
||||
if is_instruct:
|
||||
@@ -38,7 +37,7 @@ def generate_chat_prompt(user_input, max_new_tokens, name1, name2, context, chat
|
||||
prefix1 = f"{name1}: "
|
||||
prefix2 = f"{name2}: "
|
||||
|
||||
i = len(shared.history['internal'])-1
|
||||
i = len(shared.history['internal']) - 1
|
||||
while i >= 0 and len(encode(''.join(rows), max_new_tokens)[0]) < max_length:
|
||||
rows.insert(1, f"{prefix2}{shared.history['internal'][i][1].strip()}{end_of_turn}\n")
|
||||
string = shared.history['internal'][i][0]
|
||||
@@ -68,6 +67,7 @@ def generate_chat_prompt(user_input, max_new_tokens, name1, name2, context, chat
|
||||
else:
|
||||
return prompt
|
||||
|
||||
|
||||
def extract_message_from_reply(reply, name1, name2, stop_at_newline):
|
||||
next_character_found = False
|
||||
|
||||
@@ -87,7 +87,7 @@ def extract_message_from_reply(reply, name1, name2, stop_at_newline):
|
||||
# is completed, trim it
|
||||
if not next_character_found:
|
||||
for string in [f"\n{name1}:", f"\n{name2}:"]:
|
||||
for j in range(len(string)-1, 0, -1):
|
||||
for j in range(len(string) - 1, 0, -1):
|
||||
if reply[-j:] == string[:j]:
|
||||
reply = reply[:-j]
|
||||
break
|
||||
@@ -98,22 +98,25 @@ def extract_message_from_reply(reply, name1, name2, stop_at_newline):
|
||||
reply = fix_newlines(reply)
|
||||
return reply, next_character_found
|
||||
|
||||
|
||||
def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn, regenerate=False):
|
||||
if mode == 'instruct':
|
||||
stopping_strings = [f"\n{name1}", f"\n{name2}"]
|
||||
else:
|
||||
stopping_strings = [f"\n{name1}:", f"\n{name2}:"]
|
||||
|
||||
eos_token = '\n' if generate_state['stop_at_newline'] else None
|
||||
|
||||
# Defining some variables
|
||||
cumulative_reply = ''
|
||||
just_started = True
|
||||
name1_original = name1
|
||||
visible_text = custom_generate_chat_prompt = None
|
||||
eos_token = '\n' if generate_state['stop_at_newline'] else None
|
||||
if 'pygmalion' in shared.model_name.lower():
|
||||
name1 = "You"
|
||||
|
||||
# Check if any extension wants to hijack this function call
|
||||
visible_text = None
|
||||
custom_generate_chat_prompt = None
|
||||
for extension, _ in extensions_module.iterator():
|
||||
if hasattr(extension, 'input_hijack') and extension.input_hijack['state'] == True:
|
||||
if hasattr(extension, 'input_hijack') and extension.input_hijack['state']:
|
||||
extension.input_hijack['state'] = False
|
||||
text, visible_text = extension.input_hijack['value']
|
||||
if custom_generate_chat_prompt is None and hasattr(extension, 'custom_generate_chat_prompt'):
|
||||
@@ -123,6 +126,7 @@ def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_tu
|
||||
visible_text = text
|
||||
text = apply_extensions(text, "input")
|
||||
|
||||
# Generating the prompt
|
||||
kwargs = {'end_of_turn': end_of_turn, 'is_instruct': mode == 'instruct'}
|
||||
if custom_generate_chat_prompt is None:
|
||||
prompt = generate_chat_prompt(text, generate_state['max_new_tokens'], name1, name2, context, generate_state['chat_prompt_size'], **kwargs)
|
||||
@@ -131,11 +135,9 @@ def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_tu
|
||||
|
||||
# Yield *Is typing...*
|
||||
if not regenerate:
|
||||
yield shared.history['visible']+[[visible_text, shared.processing_message]]
|
||||
yield shared.history['visible'] + [[visible_text, shared.processing_message]]
|
||||
|
||||
# Generate
|
||||
cumulative_reply = ''
|
||||
just_started = True
|
||||
for i in range(generate_state['chat_generation_attempts']):
|
||||
reply = None
|
||||
for reply in generate_reply(f"{prompt}{' ' if len(cumulative_reply) > 0 else ''}{cumulative_reply}", generate_state, eos_token=eos_token, stopping_strings=stopping_strings):
|
||||
@@ -167,12 +169,15 @@ def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_tu
|
||||
|
||||
yield shared.history['visible']
|
||||
|
||||
|
||||
def impersonate_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn):
|
||||
if mode == 'instruct':
|
||||
stopping_strings = [f"\n{name1}", f"\n{name2}"]
|
||||
else:
|
||||
stopping_strings = [f"\n{name1}:", f"\n{name2}:"]
|
||||
|
||||
|
||||
# Defining some variables
|
||||
cumulative_reply = ''
|
||||
eos_token = '\n' if generate_state['stop_at_newline'] else None
|
||||
if 'pygmalion' in shared.model_name.lower():
|
||||
name1 = "You"
|
||||
@@ -182,7 +187,6 @@ def impersonate_wrapper(text, generate_state, name1, name2, context, mode, end_o
|
||||
# Yield *Is typing...*
|
||||
yield shared.processing_message
|
||||
|
||||
cumulative_reply = ''
|
||||
for i in range(generate_state['chat_generation_attempts']):
|
||||
reply = None
|
||||
for reply in generate_reply(f"{prompt}{' ' if len(cumulative_reply) > 0 else ''}{cumulative_reply}", generate_state, eos_token=eos_token, stopping_strings=stopping_strings):
|
||||
@@ -197,22 +201,25 @@ def impersonate_wrapper(text, generate_state, name1, name2, context, mode, end_o
|
||||
|
||||
yield reply
|
||||
|
||||
|
||||
def cai_chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn):
|
||||
for history in chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn):
|
||||
yield chat_html_wrapper(history, name1, name2, mode)
|
||||
|
||||
|
||||
def regenerate_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn):
|
||||
if (shared.character != 'None' and len(shared.history['visible']) == 1) or len(shared.history['internal']) == 0:
|
||||
if (len(shared.history['visible']) == 1 and not shared.history['visible'][0][0]) or len(shared.history['internal']) == 0:
|
||||
yield chat_html_wrapper(shared.history['visible'], name1, name2, mode)
|
||||
else:
|
||||
last_visible = shared.history['visible'].pop()
|
||||
last_internal = shared.history['internal'].pop()
|
||||
# Yield '*Is typing...*'
|
||||
yield chat_html_wrapper(shared.history['visible']+[[last_visible[0], shared.processing_message]], name1, name2, mode)
|
||||
yield chat_html_wrapper(shared.history['visible'] + [[last_visible[0], shared.processing_message]], name1, name2, mode)
|
||||
for history in chatbot_wrapper(last_internal[0], generate_state, name1, name2, context, mode, end_of_turn, regenerate=True):
|
||||
shared.history['visible'][-1] = [last_visible[0], history[-1][1]]
|
||||
yield chat_html_wrapper(shared.history['visible'], name1, name2, mode)
|
||||
|
||||
|
||||
def remove_last_message(name1, name2, mode):
|
||||
if len(shared.history['visible']) > 0 and shared.history['internal'][-1][0] != '<|BEGIN-VISIBLE-CHAT|>':
|
||||
last = shared.history['visible'].pop()
|
||||
@@ -222,12 +229,14 @@ def remove_last_message(name1, name2, mode):
|
||||
|
||||
return chat_html_wrapper(shared.history['visible'], name1, name2, mode), last[0]
|
||||
|
||||
|
||||
def send_last_reply_to_input():
|
||||
if len(shared.history['internal']) > 0:
|
||||
return shared.history['internal'][-1][1]
|
||||
else:
|
||||
return ''
|
||||
|
||||
|
||||
def replace_last_reply(text, name1, name2, mode):
|
||||
if len(shared.history['visible']) > 0:
|
||||
shared.history['visible'][-1][1] = text
|
||||
@@ -235,9 +244,11 @@ def replace_last_reply(text, name1, name2, mode):
|
||||
|
||||
return chat_html_wrapper(shared.history['visible'], name1, name2, mode)
|
||||
|
||||
|
||||
def clear_html():
|
||||
return chat_html_wrapper([], "", "")
|
||||
|
||||
|
||||
def clear_chat_log(name1, name2, greeting, mode):
|
||||
shared.history['visible'] = []
|
||||
shared.history['internal'] = []
|
||||
@@ -248,12 +259,14 @@ def clear_chat_log(name1, name2, greeting, mode):
|
||||
|
||||
return chat_html_wrapper(shared.history['visible'], name1, name2, mode)
|
||||
|
||||
|
||||
def redraw_html(name1, name2, mode):
|
||||
return chat_html_wrapper(shared.history['visible'], name1, name2, mode)
|
||||
|
||||
|
||||
def tokenize_dialogue(dialogue, name1, name2, mode):
|
||||
history = []
|
||||
|
||||
messages = []
|
||||
dialogue = re.sub('<START>', '', dialogue)
|
||||
dialogue = re.sub('<start>', '', dialogue)
|
||||
dialogue = re.sub('(\n|^)[Aa]non:', '\\1You:', dialogue)
|
||||
@@ -262,9 +275,8 @@ def tokenize_dialogue(dialogue, name1, name2, mode):
|
||||
if len(idx) == 0:
|
||||
return history
|
||||
|
||||
messages = []
|
||||
for i in range(len(idx)-1):
|
||||
messages.append(dialogue[idx[i]:idx[i+1]].strip())
|
||||
for i in range(len(idx) - 1):
|
||||
messages.append(dialogue[idx[i]:idx[i + 1]].strip())
|
||||
messages.append(dialogue[idx[-1]:].strip())
|
||||
|
||||
entry = ['', '']
|
||||
@@ -282,12 +294,13 @@ def tokenize_dialogue(dialogue, name1, name2, mode):
|
||||
for column in row:
|
||||
print("\n")
|
||||
for line in column.strip().split('\n'):
|
||||
print("| "+line+"\n")
|
||||
print("| " + line + "\n")
|
||||
print("|\n")
|
||||
print("------------------------------")
|
||||
|
||||
return history
|
||||
|
||||
|
||||
def save_history(timestamp=True):
|
||||
if timestamp:
|
||||
fname = f"{shared.character}_{datetime.now().strftime('%Y%m%d-%H%M%S')}.json"
|
||||
@@ -299,6 +312,7 @@ def save_history(timestamp=True):
|
||||
f.write(json.dumps({'data': shared.history['internal'], 'data_visible': shared.history['visible']}, indent=2))
|
||||
return Path(f'logs/{fname}')
|
||||
|
||||
|
||||
def load_history(file, name1, name2):
|
||||
file = file.decode('utf-8')
|
||||
try:
|
||||
@@ -313,20 +327,22 @@ def load_history(file, name1, name2):
|
||||
elif 'chat' in j:
|
||||
shared.history['internal'] = [':'.join(x.split(':')[1:]).strip() for x in j['chat']]
|
||||
if len(j['chat']) > 0 and j['chat'][0].startswith(f'{name2}:'):
|
||||
shared.history['internal'] = [['<|BEGIN-VISIBLE-CHAT|>', shared.history['internal'][0]]] + [[shared.history['internal'][i], shared.history['internal'][i+1]] for i in range(1, len(shared.history['internal'])-1, 2)]
|
||||
shared.history['internal'] = [['<|BEGIN-VISIBLE-CHAT|>', shared.history['internal'][0]]] + [[shared.history['internal'][i], shared.history['internal'][i + 1]] for i in range(1, len(shared.history['internal']) - 1, 2)]
|
||||
shared.history['visible'] = copy.deepcopy(shared.history['internal'])
|
||||
shared.history['visible'][0][0] = ''
|
||||
else:
|
||||
shared.history['internal'] = [[shared.history['internal'][i], shared.history['internal'][i+1]] for i in range(0, len(shared.history['internal'])-1, 2)]
|
||||
shared.history['internal'] = [[shared.history['internal'][i], shared.history['internal'][i + 1]] for i in range(0, len(shared.history['internal']) - 1, 2)]
|
||||
shared.history['visible'] = copy.deepcopy(shared.history['internal'])
|
||||
except:
|
||||
shared.history['internal'] = tokenize_dialogue(file, name1, name2)
|
||||
shared.history['visible'] = copy.deepcopy(shared.history['internal'])
|
||||
|
||||
|
||||
def replace_character_names(text, name1, name2):
|
||||
text = text.replace('{{user}}', name1).replace('{{char}}', name2)
|
||||
return text.replace('<USER>', name1).replace('<BOT>', name2)
|
||||
|
||||
|
||||
def build_pygmalion_style_context(data):
|
||||
context = ""
|
||||
if 'char_persona' in data and data['char_persona'] != '':
|
||||
@@ -336,6 +352,7 @@ def build_pygmalion_style_context(data):
|
||||
context = f"{context.strip()}\n<START>\n"
|
||||
return context
|
||||
|
||||
|
||||
def generate_pfp_cache(character):
|
||||
cache_folder = Path("cache")
|
||||
if not cache_folder.exists():
|
||||
@@ -348,6 +365,7 @@ def generate_pfp_cache(character):
|
||||
return img
|
||||
return None
|
||||
|
||||
|
||||
def load_character(character, name1, name2, mode):
|
||||
shared.character = character
|
||||
shared.history['internal'] = []
|
||||
@@ -387,13 +405,13 @@ def load_character(character, name1, name2, mode):
|
||||
if 'example_dialogue' in data:
|
||||
context += f"{data['example_dialogue'].strip()}\n"
|
||||
if greeting_field in data:
|
||||
greeting = data[greeting_field]
|
||||
greeting = data[greeting_field]
|
||||
if 'end_of_turn' in data:
|
||||
end_of_turn = data['end_of_turn']
|
||||
end_of_turn = data['end_of_turn']
|
||||
else:
|
||||
context = shared.settings['context']
|
||||
name2 = shared.settings['name2']
|
||||
greeting = shared.settings['greeting']
|
||||
greeting = shared.settings['greeting']
|
||||
end_of_turn = shared.settings['end_of_turn']
|
||||
|
||||
if Path(f'logs/{shared.character}_persistent.json').exists():
|
||||
@@ -404,9 +422,11 @@ def load_character(character, name1, name2, mode):
|
||||
|
||||
return name1, name2, picture, greeting, context, end_of_turn, chat_html_wrapper(shared.history['visible'], name1, name2, mode, reset_cache=True)
|
||||
|
||||
|
||||
def load_default_history(name1, name2):
|
||||
load_character("None", name1, name2, "chat")
|
||||
|
||||
|
||||
def upload_character(json_file, img, tavern=False):
|
||||
json_file = json_file if type(json_file) == str else json_file.decode('utf-8')
|
||||
data = json.loads(json_file)
|
||||
@@ -425,6 +445,7 @@ def upload_character(json_file, img, tavern=False):
|
||||
print(f'New character saved to "characters/{outfile_name}.json".')
|
||||
return outfile_name
|
||||
|
||||
|
||||
def upload_tavern_character(img, name1, name2):
|
||||
_img = Image.open(io.BytesIO(img))
|
||||
_img.getexif()
|
||||
@@ -433,12 +454,13 @@ def upload_tavern_character(img, name1, name2):
|
||||
_json = {"char_name": _json['name'], "char_persona": _json['description'], "char_greeting": _json["first_mes"], "example_dialogue": _json['mes_example'], "world_scenario": _json['scenario']}
|
||||
return upload_character(json.dumps(_json), img, tavern=True)
|
||||
|
||||
|
||||
def upload_your_profile_picture(img, name1, name2, mode):
|
||||
cache_folder = Path("cache")
|
||||
if not cache_folder.exists():
|
||||
cache_folder.mkdir()
|
||||
|
||||
if img == None:
|
||||
if img is None:
|
||||
if Path("cache/pfp_me.png").exists():
|
||||
Path("cache/pfp_me.png").unlink()
|
||||
else:
|
||||
|
||||
@@ -9,25 +9,32 @@ state = {}
|
||||
available_extensions = []
|
||||
setup_called = set()
|
||||
|
||||
|
||||
def load_extensions():
|
||||
global state
|
||||
global state, setup_called
|
||||
for i, name in enumerate(shared.args.extensions):
|
||||
if name in available_extensions:
|
||||
print(f'Loading the extension "{name}"... ', end='')
|
||||
try:
|
||||
exec(f"import extensions.{name}.script")
|
||||
extension = eval(f"extensions.{name}.script")
|
||||
if extension not in setup_called and hasattr(extension, "setup"):
|
||||
setup_called.add(extension)
|
||||
extension.setup()
|
||||
state[name] = [True, i]
|
||||
print('Ok.')
|
||||
except:
|
||||
print('Fail.')
|
||||
traceback.print_exc()
|
||||
|
||||
|
||||
# This iterator returns the extensions in the order specified in the command-line
|
||||
def iterator():
|
||||
for name in sorted(state, key=lambda x : state[x][1]):
|
||||
if state[name][0] == True:
|
||||
for name in sorted(state, key=lambda x: state[x][1]):
|
||||
if state[name][0]:
|
||||
yield eval(f"extensions.{name}.script"), name
|
||||
|
||||
|
||||
# Extension functions that map string -> string
|
||||
def apply_extensions(text, typ):
|
||||
for extension, _ in iterator():
|
||||
@@ -39,6 +46,7 @@ def apply_extensions(text, typ):
|
||||
text = extension.bot_prefix_modifier(text)
|
||||
return text
|
||||
|
||||
|
||||
def create_extensions_block():
|
||||
global setup_called
|
||||
|
||||
@@ -51,14 +59,9 @@ def create_extensions_block():
|
||||
extension.params[param] = shared.settings[_id]
|
||||
|
||||
should_display_ui = False
|
||||
|
||||
# Running setup function
|
||||
for extension, name in iterator():
|
||||
if hasattr(extension, "ui"):
|
||||
should_display_ui = True
|
||||
if extension not in setup_called and hasattr(extension, "setup"):
|
||||
setup_called.add(extension)
|
||||
extension.setup()
|
||||
|
||||
# Creating the extension ui elements
|
||||
if should_display_ui:
|
||||
|
||||
@@ -24,6 +24,7 @@ with open(Path(__file__).resolve().parent / '../css/html_cai_style.css', 'r') as
|
||||
with open(Path(__file__).resolve().parent / '../css/html_instruct_style.css', 'r') as f:
|
||||
instruct_css = f.read()
|
||||
|
||||
|
||||
def fix_newlines(string):
|
||||
string = string.replace('\n', '\n\n')
|
||||
string = re.sub(r"\n{3,}", "\n\n", string)
|
||||
@@ -31,6 +32,8 @@ def fix_newlines(string):
|
||||
return string
|
||||
|
||||
# This could probably be generalized and improved
|
||||
|
||||
|
||||
def convert_to_markdown(string):
|
||||
string = string.replace('\\begin{code}', '```')
|
||||
string = string.replace('\\end{code}', '```')
|
||||
@@ -38,13 +41,15 @@ def convert_to_markdown(string):
|
||||
string = string.replace('\\end{blockquote}', '')
|
||||
string = re.sub(r"(.)```", r"\1\n```", string)
|
||||
string = fix_newlines(string)
|
||||
return markdown.markdown(string, extensions=['fenced_code'])
|
||||
return markdown.markdown(string, extensions=['fenced_code'])
|
||||
|
||||
|
||||
def generate_basic_html(string):
|
||||
string = convert_to_markdown(string)
|
||||
string = f'<style>{readable_css}</style><div class="container">{string}</div>'
|
||||
return string
|
||||
|
||||
|
||||
def process_post(post, c):
|
||||
t = post.split('\n')
|
||||
number = t[0].split(' ')[1]
|
||||
@@ -59,6 +64,7 @@ def process_post(post, c):
|
||||
src = f'<span class="name">Anonymous </span> <span class="number">No.{number}</span>\n{src}'
|
||||
return src
|
||||
|
||||
|
||||
def generate_4chan_html(f):
|
||||
posts = []
|
||||
post = ''
|
||||
@@ -84,7 +90,7 @@ def generate_4chan_html(f):
|
||||
posts[i] = f'<div class="op">{posts[i]}</div>\n'
|
||||
else:
|
||||
posts[i] = f'<div class="reply">{posts[i]}</div>\n'
|
||||
|
||||
|
||||
output = ''
|
||||
output += f'<style>{_4chan_css}</style><div id="parent"><div id="container">'
|
||||
for post in posts:
|
||||
@@ -98,13 +104,15 @@ def generate_4chan_html(f):
|
||||
|
||||
return output
|
||||
|
||||
|
||||
def make_thumbnail(image):
|
||||
image = image.resize((350, round(image.size[1]/image.size[0]*350)), Image.Resampling.LANCZOS)
|
||||
image = image.resize((350, round(image.size[1] / image.size[0] * 350)), Image.Resampling.LANCZOS)
|
||||
if image.size[1] > 470:
|
||||
image = ImageOps.fit(image, (350, 470), Image.ANTIALIAS)
|
||||
|
||||
return image
|
||||
|
||||
|
||||
def get_image_cache(path):
|
||||
cache_folder = Path("cache")
|
||||
if not cache_folder.exists():
|
||||
@@ -119,9 +127,10 @@ def get_image_cache(path):
|
||||
|
||||
return image_cache[path][1]
|
||||
|
||||
|
||||
def generate_instruct_html(history):
|
||||
output = f'<style>{instruct_css}</style><div class="chat" id="chat">'
|
||||
for i,_row in enumerate(history[::-1]):
|
||||
for i, _row in enumerate(history[::-1]):
|
||||
row = [convert_to_markdown(entry) for entry in _row]
|
||||
|
||||
output += f"""
|
||||
@@ -134,7 +143,7 @@ def generate_instruct_html(history):
|
||||
</div>
|
||||
"""
|
||||
|
||||
if len(row[0]) == 0: # don't display empty user messages
|
||||
if len(row[0]) == 0: # don't display empty user messages
|
||||
continue
|
||||
|
||||
output += f"""
|
||||
@@ -151,6 +160,7 @@ def generate_instruct_html(history):
|
||||
|
||||
return output
|
||||
|
||||
|
||||
def generate_cai_chat_html(history, name1, name2, reset_cache=False):
|
||||
output = f'<style>{cai_css}</style><div class="chat" id="chat">'
|
||||
|
||||
@@ -159,7 +169,7 @@ def generate_cai_chat_html(history, name1, name2, reset_cache=False):
|
||||
img_bot = f'<img src="file/cache/pfp_character.png{suffix}">' if Path("cache/pfp_character.png").exists() else ''
|
||||
img_me = f'<img src="file/cache/pfp_me.png{suffix}">' if Path("cache/pfp_me.png").exists() else ''
|
||||
|
||||
for i,_row in enumerate(history[::-1]):
|
||||
for i, _row in enumerate(history[::-1]):
|
||||
row = [convert_to_markdown(entry) for entry in _row]
|
||||
|
||||
output += f"""
|
||||
@@ -178,7 +188,7 @@ def generate_cai_chat_html(history, name1, name2, reset_cache=False):
|
||||
</div>
|
||||
"""
|
||||
|
||||
if len(row[0]) == 0: # don't display empty user messages
|
||||
if len(row[0]) == 0: # don't display empty user messages
|
||||
continue
|
||||
|
||||
output += f"""
|
||||
@@ -200,9 +210,11 @@ def generate_cai_chat_html(history, name1, name2, reset_cache=False):
|
||||
output += "</div>"
|
||||
return output
|
||||
|
||||
|
||||
def generate_chat_html(history, name1, name2):
|
||||
return generate_cai_chat_html(history, name1, name2)
|
||||
|
||||
|
||||
def chat_html_wrapper(history, name1, name2, mode, reset_cache=False):
|
||||
if mode == "cai-chat":
|
||||
return generate_cai_chat_html(history, name1, name2, reset_cache)
|
||||
|
||||
@@ -50,9 +50,9 @@ class LlamaCppModel:
|
||||
params.top_k = top_k
|
||||
params.temp = temperature
|
||||
params.repeat_penalty = repetition_penalty
|
||||
#params.repeat_last_n = repeat_last_n
|
||||
# params.repeat_last_n = repeat_last_n
|
||||
|
||||
#self.model.params = params
|
||||
# self.model.params = params
|
||||
self.model.add_bos()
|
||||
self.model.update_input(context)
|
||||
|
||||
|
||||
@@ -1,13 +1,11 @@
|
||||
'''
|
||||
Based on
|
||||
Based on
|
||||
https://github.com/abetlen/llama-cpp-python
|
||||
|
||||
Documentation:
|
||||
https://abetlen.github.io/llama-cpp-python/
|
||||
'''
|
||||
|
||||
import multiprocessing
|
||||
|
||||
from llama_cpp import Llama
|
||||
|
||||
from modules import shared
|
||||
@@ -31,7 +29,7 @@ class LlamaCppModel:
|
||||
self.model = Llama(**params)
|
||||
|
||||
# This is ugly, but the model and the tokenizer are the same object in this library.
|
||||
return result, result
|
||||
return result, result
|
||||
|
||||
def encode(self, string):
|
||||
if type(string) is str:
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import gc
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
@@ -16,11 +17,10 @@ import modules.shared as shared
|
||||
|
||||
transformers.logging.set_verbosity_error()
|
||||
|
||||
local_rank = None
|
||||
|
||||
if shared.args.flexgen:
|
||||
from flexgen.flex_opt import CompressionConfig, ExecutionEnv, OptLM, Policy
|
||||
|
||||
local_rank = None
|
||||
if shared.args.deepspeed:
|
||||
import deepspeed
|
||||
from transformers.deepspeed import (HfDeepSpeedConfig,
|
||||
@@ -34,7 +34,7 @@ if shared.args.deepspeed:
|
||||
torch.cuda.set_device(local_rank)
|
||||
deepspeed.init_distributed()
|
||||
ds_config = generate_ds_config(shared.args.bf16, 1 * world_size, shared.args.nvme_offload_dir)
|
||||
dschf = HfDeepSpeedConfig(ds_config) # Keep this object alive for the Transformers integration
|
||||
dschf = HfDeepSpeedConfig(ds_config) # Keep this object alive for the Transformers integration
|
||||
|
||||
|
||||
def load_model(model_name):
|
||||
@@ -83,7 +83,7 @@ def load_model(model_name):
|
||||
elif shared.args.deepspeed:
|
||||
model = AutoModelForCausalLM.from_pretrained(Path(f"{shared.args.model_dir}/{shared.model_name}"), torch_dtype=torch.bfloat16 if shared.args.bf16 else torch.float16)
|
||||
model = deepspeed.initialize(model=model, config_params=ds_config, model_parameters=None, optimizer=None, lr_scheduler=None)[0]
|
||||
model.module.eval() # Inference
|
||||
model.module.eval() # Inference
|
||||
print(f"DeepSpeed ZeRO-3 is enabled: {is_deepspeed_zero3_enabled()}")
|
||||
|
||||
# RMKV model (not on HuggingFace)
|
||||
@@ -132,7 +132,7 @@ def load_model(model_name):
|
||||
params["torch_dtype"] = torch.float16
|
||||
|
||||
if shared.args.gpu_memory:
|
||||
memory_map = list(map(lambda x : x.strip(), shared.args.gpu_memory))
|
||||
memory_map = list(map(lambda x: x.strip(), shared.args.gpu_memory))
|
||||
max_cpu_memory = shared.args.cpu_memory.strip() if shared.args.cpu_memory is not None else '99GiB'
|
||||
max_memory = {}
|
||||
for i in range(len(memory_map)):
|
||||
@@ -140,13 +140,13 @@ def load_model(model_name):
|
||||
max_memory['cpu'] = max_cpu_memory
|
||||
params['max_memory'] = max_memory
|
||||
elif shared.args.auto_devices:
|
||||
total_mem = (torch.cuda.get_device_properties(0).total_memory / (1024*1024))
|
||||
suggestion = round((total_mem-1000) / 1000) * 1000
|
||||
total_mem = (torch.cuda.get_device_properties(0).total_memory / (1024 * 1024))
|
||||
suggestion = round((total_mem - 1000) / 1000) * 1000
|
||||
if total_mem - suggestion < 800:
|
||||
suggestion -= 1000
|
||||
suggestion = int(round(suggestion/1000))
|
||||
suggestion = int(round(suggestion / 1000))
|
||||
print(f"\033[1;32;1mAuto-assiging --gpu-memory {suggestion} for your GPU to try to prevent out-of-memory errors.\nYou can manually set other values.\033[0;37;0m")
|
||||
|
||||
|
||||
max_memory = {0: f'{suggestion}GiB', 'cpu': f'{shared.args.cpu_memory or 99}GiB'}
|
||||
params['max_memory'] = max_memory
|
||||
|
||||
@@ -161,10 +161,10 @@ def load_model(model_name):
|
||||
model = AutoModelForCausalLM.from_config(config)
|
||||
model.tie_weights()
|
||||
params['device_map'] = infer_auto_device_map(
|
||||
model,
|
||||
dtype=torch.int8,
|
||||
model,
|
||||
dtype=torch.int8,
|
||||
max_memory=params['max_memory'],
|
||||
no_split_module_classes = model._no_split_modules
|
||||
no_split_module_classes=model._no_split_modules
|
||||
)
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(checkpoint, **params)
|
||||
@@ -181,6 +181,23 @@ def load_model(model_name):
|
||||
print(f"Loaded the model in {(time.time()-t0):.2f} seconds.")
|
||||
return model, tokenizer
|
||||
|
||||
|
||||
def clear_torch_cache():
|
||||
gc.collect()
|
||||
if not shared.args.cpu:
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
def unload_model():
|
||||
shared.model = shared.tokenizer = None
|
||||
clear_torch_cache()
|
||||
|
||||
|
||||
def reload_model():
|
||||
unload_model()
|
||||
shared.model, shared.tokenizer = load_model(shared.model_name)
|
||||
|
||||
|
||||
def load_soft_prompt(name):
|
||||
if name == 'None':
|
||||
shared.soft_prompt = False
|
||||
|
||||
@@ -61,6 +61,7 @@ settings = {
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def str2bool(v):
|
||||
if isinstance(v, bool):
|
||||
return v
|
||||
@@ -71,7 +72,8 @@ def str2bool(v):
|
||||
else:
|
||||
raise argparse.ArgumentTypeError('Boolean value expected.')
|
||||
|
||||
parser = argparse.ArgumentParser(formatter_class=lambda prog: argparse.HelpFormatter(prog,max_help_position=54))
|
||||
|
||||
parser = argparse.ArgumentParser(formatter_class=lambda prog: argparse.HelpFormatter(prog, max_help_position=54))
|
||||
|
||||
# Basic settings
|
||||
parser.add_argument('--notebook', action='store_true', help='Launch the web UI in notebook mode, where the output is written to the same text box as the input.')
|
||||
@@ -145,5 +147,6 @@ if args.cai_chat:
|
||||
print("Warning: --cai-chat is deprecated. Use --chat instead.")
|
||||
args.chat = True
|
||||
|
||||
|
||||
def is_chat():
|
||||
return args.chat
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import gc
|
||||
import re
|
||||
import time
|
||||
import traceback
|
||||
@@ -12,15 +11,16 @@ from modules.callbacks import (Iteratorize, Stream,
|
||||
_SentinelTokenStoppingCriteria)
|
||||
from modules.extensions import apply_extensions
|
||||
from modules.html_generator import generate_4chan_html, generate_basic_html
|
||||
from modules.models import local_rank
|
||||
from modules.models import clear_torch_cache, local_rank
|
||||
|
||||
|
||||
def get_max_prompt_length(tokens):
|
||||
max_length = 2048-tokens
|
||||
max_length = 2048 - tokens
|
||||
if shared.soft_prompt:
|
||||
max_length -= shared.soft_prompt_tensor.shape[1]
|
||||
return max_length
|
||||
|
||||
|
||||
def encode(prompt, tokens_to_generate=0, add_special_tokens=True):
|
||||
if any((shared.is_RWKV, shared.is_llamacpp)):
|
||||
input_ids = shared.tokenizer.encode(str(prompt))
|
||||
@@ -30,7 +30,7 @@ def encode(prompt, tokens_to_generate=0, add_special_tokens=True):
|
||||
input_ids = shared.tokenizer.encode(str(prompt), return_tensors='pt', truncation=True, max_length=get_max_prompt_length(tokens_to_generate), add_special_tokens=add_special_tokens)
|
||||
|
||||
if type(shared.tokenizer) is transformers.LlamaTokenizer and input_ids[0][0] == 29871:
|
||||
input_ids = input_ids[:,1:]
|
||||
input_ids = input_ids[:, 1:]
|
||||
|
||||
if shared.args.cpu:
|
||||
return input_ids
|
||||
@@ -44,6 +44,7 @@ def encode(prompt, tokens_to_generate=0, add_special_tokens=True):
|
||||
else:
|
||||
return input_ids.cuda()
|
||||
|
||||
|
||||
def decode(output_ids):
|
||||
# Open Assistant relies on special tokens like <|endoftext|>
|
||||
if re.match('.*(oasst|galactica)-*', shared.model_name.lower()):
|
||||
@@ -53,14 +54,17 @@ def decode(output_ids):
|
||||
reply = reply.replace(r'<|endoftext|>', '')
|
||||
return reply
|
||||
|
||||
|
||||
def generate_softprompt_input_tensors(input_ids):
|
||||
inputs_embeds = shared.model.transformer.wte(input_ids)
|
||||
inputs_embeds = torch.cat((shared.soft_prompt_tensor, inputs_embeds), dim=1)
|
||||
filler_input_ids = torch.zeros((1, inputs_embeds.shape[1]), dtype=input_ids.dtype).to(shared.model.device)
|
||||
#filler_input_ids += shared.model.config.bos_token_id # setting dummy input_ids to bos tokens
|
||||
# filler_input_ids += shared.model.config.bos_token_id # setting dummy input_ids to bos tokens
|
||||
return inputs_embeds, filler_input_ids
|
||||
|
||||
# Removes empty replies from gpt4chan outputs
|
||||
|
||||
|
||||
def fix_gpt4chan(s):
|
||||
for i in range(10):
|
||||
s = re.sub("--- [0-9]*\n>>[0-9]*\n---", "---", s)
|
||||
@@ -69,6 +73,8 @@ def fix_gpt4chan(s):
|
||||
return s
|
||||
|
||||
# Fix the LaTeX equations in galactica
|
||||
|
||||
|
||||
def fix_galactica(s):
|
||||
s = s.replace(r'\[', r'$')
|
||||
s = s.replace(r'\]', r'$')
|
||||
@@ -79,6 +85,7 @@ def fix_galactica(s):
|
||||
s = re.sub(r"\n{3,}", "\n\n", s)
|
||||
return s
|
||||
|
||||
|
||||
def formatted_outputs(reply, model_name):
|
||||
if not shared.is_chat():
|
||||
if 'galactica' in model_name.lower():
|
||||
@@ -92,10 +99,6 @@ def formatted_outputs(reply, model_name):
|
||||
else:
|
||||
return reply
|
||||
|
||||
def clear_torch_cache():
|
||||
gc.collect()
|
||||
if not shared.args.cpu:
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def set_manual_seed(seed):
|
||||
if seed != -1:
|
||||
@@ -103,9 +106,11 @@ def set_manual_seed(seed):
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
|
||||
|
||||
def stop_everything_event():
|
||||
shared.stop_everything = True
|
||||
|
||||
|
||||
def generate_reply(question, generate_state, eos_token=None, stopping_strings=[]):
|
||||
clear_torch_cache()
|
||||
set_manual_seed(generate_state['seed'])
|
||||
@@ -115,22 +120,22 @@ def generate_reply(question, generate_state, eos_token=None, stopping_strings=[]
|
||||
|
||||
original_question = question
|
||||
if not shared.is_chat():
|
||||
question = apply_extensions(question, "input")
|
||||
question = apply_extensions(question, 'input')
|
||||
if shared.args.verbose:
|
||||
print(f"\n\n{question}\n--------------------\n")
|
||||
print(f'\n\n{question}\n--------------------\n')
|
||||
|
||||
# These models are not part of Hugging Face, so we handle them
|
||||
# separately and terminate the function call earlier
|
||||
if any((shared.is_RWKV, shared.is_llamacpp)):
|
||||
for k in ['temperature', 'top_p', 'top_k', 'repetition_penalty']:
|
||||
generate_params[k] = generate_state[k]
|
||||
generate_params["token_count"] = generate_state["max_new_tokens"]
|
||||
generate_params['token_count'] = generate_state['max_new_tokens']
|
||||
try:
|
||||
if shared.args.no_stream:
|
||||
reply = shared.model.generate(context=question, **generate_params)
|
||||
output = original_question+reply
|
||||
output = original_question + reply
|
||||
if not shared.is_chat():
|
||||
reply = original_question + apply_extensions(reply, "output")
|
||||
reply = original_question + apply_extensions(reply, 'output')
|
||||
yield formatted_outputs(reply, shared.model_name)
|
||||
else:
|
||||
if not shared.is_chat():
|
||||
@@ -139,9 +144,9 @@ def generate_reply(question, generate_state, eos_token=None, stopping_strings=[]
|
||||
# RWKV has proper streaming, which is very nice.
|
||||
# No need to generate 8 tokens at a time.
|
||||
for reply in shared.model.generate_with_streaming(context=question, **generate_params):
|
||||
output = original_question+reply
|
||||
output = original_question + reply
|
||||
if not shared.is_chat():
|
||||
reply = original_question + apply_extensions(reply, "output")
|
||||
reply = original_question + apply_extensions(reply, 'output')
|
||||
yield formatted_outputs(reply, shared.model_name)
|
||||
|
||||
except Exception:
|
||||
@@ -150,7 +155,7 @@ def generate_reply(question, generate_state, eos_token=None, stopping_strings=[]
|
||||
t1 = time.time()
|
||||
original_tokens = len(encode(original_question)[0])
|
||||
new_tokens = len(encode(output)[0]) - original_tokens
|
||||
print(f"Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens})")
|
||||
print(f'Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens})')
|
||||
return
|
||||
|
||||
input_ids = encode(question, generate_state['max_new_tokens'])
|
||||
@@ -166,31 +171,30 @@ def generate_reply(question, generate_state, eos_token=None, stopping_strings=[]
|
||||
t = [encode(string, 0, add_special_tokens=False) for string in stopping_strings]
|
||||
stopping_criteria_list.append(_SentinelTokenStoppingCriteria(sentinel_token_ids=t, starting_idx=len(input_ids[0])))
|
||||
|
||||
generate_params["max_new_tokens"] = generate_state['max_new_tokens']
|
||||
if not shared.args.flexgen:
|
||||
for k in ["do_sample", "temperature", "top_p", "typical_p", "repetition_penalty", "encoder_repetition_penalty", "top_k", "min_length", "no_repeat_ngram_size", "num_beams", "penalty_alpha", "length_penalty", "early_stopping"]:
|
||||
for k in ['max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']:
|
||||
generate_params[k] = generate_state[k]
|
||||
generate_params["eos_token_id"] = eos_token_ids
|
||||
generate_params["stopping_criteria"] = stopping_criteria_list
|
||||
generate_params['eos_token_id'] = eos_token_ids
|
||||
generate_params['stopping_criteria'] = stopping_criteria_list
|
||||
if shared.args.no_stream:
|
||||
generate_params["min_length"] = 0
|
||||
generate_params['min_length'] = 0
|
||||
else:
|
||||
for k in ["do_sample", "temperature"]:
|
||||
for k in ['max_new_tokens', 'do_sample', 'temperature']:
|
||||
generate_params[k] = generate_state[k]
|
||||
generate_params["stop"] = generate_state["eos_token_ids"][-1]
|
||||
generate_params['stop'] = generate_state['eos_token_ids'][-1]
|
||||
if not shared.args.no_stream:
|
||||
generate_params["max_new_tokens"] = 8
|
||||
generate_params['max_new_tokens'] = 8
|
||||
|
||||
if shared.args.no_cache:
|
||||
generate_params.update({"use_cache": False})
|
||||
generate_params.update({'use_cache': False})
|
||||
if shared.args.deepspeed:
|
||||
generate_params.update({"synced_gpus": True})
|
||||
generate_params.update({'synced_gpus': True})
|
||||
if shared.soft_prompt:
|
||||
inputs_embeds, filler_input_ids = generate_softprompt_input_tensors(input_ids)
|
||||
generate_params.update({"inputs_embeds": inputs_embeds})
|
||||
generate_params.update({"inputs": filler_input_ids})
|
||||
generate_params.update({'inputs_embeds': inputs_embeds})
|
||||
generate_params.update({'inputs': filler_input_ids})
|
||||
else:
|
||||
generate_params.update({"inputs": input_ids})
|
||||
generate_params.update({'inputs': input_ids})
|
||||
|
||||
try:
|
||||
# Generate the entire reply at once.
|
||||
@@ -205,7 +209,7 @@ def generate_reply(question, generate_state, eos_token=None, stopping_strings=[]
|
||||
new_tokens = len(output) - len(input_ids[0])
|
||||
reply = decode(output[-new_tokens:])
|
||||
if not shared.is_chat():
|
||||
reply = original_question + apply_extensions(reply, "output")
|
||||
reply = original_question + apply_extensions(reply, 'output')
|
||||
|
||||
yield formatted_outputs(reply, shared.model_name)
|
||||
|
||||
@@ -232,7 +236,7 @@ def generate_reply(question, generate_state, eos_token=None, stopping_strings=[]
|
||||
new_tokens = len(output) - len(input_ids[0])
|
||||
reply = decode(output[-new_tokens:])
|
||||
if not shared.is_chat():
|
||||
reply = original_question + apply_extensions(reply, "output")
|
||||
reply = original_question + apply_extensions(reply, 'output')
|
||||
|
||||
if output[-1] in eos_token_ids:
|
||||
break
|
||||
@@ -240,7 +244,7 @@ def generate_reply(question, generate_state, eos_token=None, stopping_strings=[]
|
||||
|
||||
# Stream the output naively for FlexGen since it doesn't support 'stopping_criteria'
|
||||
else:
|
||||
for i in range(generate_state['max_new_tokens']//8+1):
|
||||
for i in range(generate_state['max_new_tokens'] // 8 + 1):
|
||||
clear_torch_cache()
|
||||
with torch.no_grad():
|
||||
output = shared.model.generate(**generate_params)[0]
|
||||
@@ -250,7 +254,7 @@ def generate_reply(question, generate_state, eos_token=None, stopping_strings=[]
|
||||
new_tokens = len(output) - len(original_input_ids[0])
|
||||
reply = decode(output[-new_tokens:])
|
||||
if not shared.is_chat():
|
||||
reply = original_question + apply_extensions(reply, "output")
|
||||
reply = original_question + apply_extensions(reply, 'output')
|
||||
|
||||
if np.count_nonzero(np.isin(input_ids[0], eos_token_ids)) < np.count_nonzero(np.isin(output, eos_token_ids)):
|
||||
break
|
||||
@@ -259,10 +263,10 @@ def generate_reply(question, generate_state, eos_token=None, stopping_strings=[]
|
||||
input_ids = np.reshape(output, (1, output.shape[0]))
|
||||
if shared.soft_prompt:
|
||||
inputs_embeds, filler_input_ids = generate_softprompt_input_tensors(input_ids)
|
||||
generate_params.update({"inputs_embeds": inputs_embeds})
|
||||
generate_params.update({"inputs": filler_input_ids})
|
||||
generate_params.update({'inputs_embeds': inputs_embeds})
|
||||
generate_params.update({'inputs': filler_input_ids})
|
||||
else:
|
||||
generate_params.update({"inputs": input_ids})
|
||||
generate_params.update({'inputs': input_ids})
|
||||
|
||||
yield formatted_outputs(reply, shared.model_name)
|
||||
|
||||
@@ -271,6 +275,6 @@ def generate_reply(question, generate_state, eos_token=None, stopping_strings=[]
|
||||
finally:
|
||||
t1 = time.time()
|
||||
original_tokens = len(original_input_ids[0])
|
||||
new_tokens = len(output)-original_tokens
|
||||
print(f"Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens})")
|
||||
new_tokens = len(output) - original_tokens
|
||||
print(f'Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens})')
|
||||
return
|
||||
|
||||
@@ -19,9 +19,11 @@ CURRENT_STEPS = 0
|
||||
MAX_STEPS = 0
|
||||
CURRENT_GRADIENT_ACCUM = 1
|
||||
|
||||
|
||||
def get_dataset(path: str, ext: str):
|
||||
return ['None'] + sorted(set([k.stem for k in Path(path).glob(f'*.{ext}') if k.stem != 'put-trainer-datasets-here']), key=str.lower)
|
||||
|
||||
|
||||
def create_train_interface():
|
||||
with gr.Tab('Train LoRA', elem_id='lora-train-tab'):
|
||||
lora_name = gr.Textbox(label="Name", info="The name of your new LoRA file")
|
||||
@@ -44,16 +46,16 @@ def create_train_interface():
|
||||
with gr.Tab(label="Formatted Dataset"):
|
||||
with gr.Row():
|
||||
dataset = gr.Dropdown(choices=get_dataset('training/datasets', 'json'), value='None', label='Dataset', info='The dataset file to use for training.')
|
||||
ui.create_refresh_button(dataset, lambda : None, lambda : {'choices': get_dataset('training/datasets', 'json')}, 'refresh-button')
|
||||
ui.create_refresh_button(dataset, lambda: None, lambda: {'choices': get_dataset('training/datasets', 'json')}, 'refresh-button')
|
||||
eval_dataset = gr.Dropdown(choices=get_dataset('training/datasets', 'json'), value='None', label='Evaluation Dataset', info='The (optional) dataset file used to evaluate the model after training.')
|
||||
ui.create_refresh_button(eval_dataset, lambda : None, lambda : {'choices': get_dataset('training/datasets', 'json')}, 'refresh-button')
|
||||
ui.create_refresh_button(eval_dataset, lambda: None, lambda: {'choices': get_dataset('training/datasets', 'json')}, 'refresh-button')
|
||||
format = gr.Dropdown(choices=get_dataset('training/formats', 'json'), value='None', label='Data Format', info='The format file used to decide how to format the dataset input.')
|
||||
ui.create_refresh_button(format, lambda : None, lambda : {'choices': get_dataset('training/formats', 'json')}, 'refresh-button')
|
||||
ui.create_refresh_button(format, lambda: None, lambda: {'choices': get_dataset('training/formats', 'json')}, 'refresh-button')
|
||||
|
||||
with gr.Tab(label="Raw Text File"):
|
||||
with gr.Row():
|
||||
raw_text_file = gr.Dropdown(choices=get_dataset('training/datasets', 'txt'), value='None', label='Text File', info='The raw text file to use for training.')
|
||||
ui.create_refresh_button(raw_text_file, lambda : None, lambda : {'choices': get_dataset('training/datasets', 'txt')}, 'refresh-button')
|
||||
ui.create_refresh_button(raw_text_file, lambda: None, lambda: {'choices': get_dataset('training/datasets', 'txt')}, 'refresh-button')
|
||||
with gr.Row():
|
||||
overlap_len = gr.Slider(label='Overlap Length', minimum=0, maximum=512, value=128, step=16, info='Overlap length - ie how many tokens from the prior chunk of text to include into the next chunk. (The chunks themselves will be of a size determined by Cutoff Length below). Setting overlap to exactly half the cutoff length may be ideal.')
|
||||
newline_favor_len = gr.Slider(label='Prefer Newline Cut Length', minimum=0, maximum=512, value=128, step=16, info='Length (in characters, not tokens) of the maximum distance to shift an overlap cut by to ensure chunks cut at newlines. If too low, cuts may occur in the middle of lines.')
|
||||
@@ -67,10 +69,12 @@ def create_train_interface():
|
||||
cutoff_len, dataset, eval_dataset, format, raw_text_file, overlap_len, newline_favor_len], [output])
|
||||
stop_button.click(do_interrupt, [], [], cancels=[], queue=False)
|
||||
|
||||
|
||||
def do_interrupt():
|
||||
global WANT_INTERRUPT
|
||||
WANT_INTERRUPT = True
|
||||
|
||||
|
||||
class Callbacks(transformers.TrainerCallback):
|
||||
def on_step_begin(self, args: transformers.TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, **kwargs):
|
||||
global CURRENT_STEPS, MAX_STEPS
|
||||
@@ -79,6 +83,7 @@ class Callbacks(transformers.TrainerCallback):
|
||||
if WANT_INTERRUPT:
|
||||
control.should_epoch_stop = True
|
||||
control.should_training_stop = True
|
||||
|
||||
def on_substep_end(self, args: transformers.TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, **kwargs):
|
||||
global CURRENT_STEPS
|
||||
CURRENT_STEPS += 1
|
||||
@@ -86,6 +91,7 @@ class Callbacks(transformers.TrainerCallback):
|
||||
control.should_epoch_stop = True
|
||||
control.should_training_stop = True
|
||||
|
||||
|
||||
def clean_path(base_path: str, path: str):
|
||||
""""Strips unusual symbols and forcibly builds a path as relative to the intended directory."""
|
||||
# TODO: Probably could do with a security audit to guarantee there's no ways this can be bypassed to target an unwanted path.
|
||||
@@ -95,6 +101,7 @@ def clean_path(base_path: str, path: str):
|
||||
return path
|
||||
return f'{Path(base_path).absolute()}/{path}'
|
||||
|
||||
|
||||
def do_train(lora_name: str, micro_batch_size: int, batch_size: int, epochs: int, learning_rate: str, lora_rank: int, lora_alpha: int, lora_dropout: float,
|
||||
cutoff_len: int, dataset: str, eval_dataset: str, format: str, raw_text_file: str, overlap_len: int, newline_favor_len: int):
|
||||
global WANT_INTERRUPT, CURRENT_STEPS, MAX_STEPS, CURRENT_GRADIENT_ACCUM
|
||||
@@ -124,7 +131,7 @@ def do_train(lora_name: str, micro_batch_size: int, batch_size: int, epochs: int
|
||||
elif not shared.args.load_in_8bit:
|
||||
yield "It is highly recommended you use `--load-in-8bit` for LoRA training. *(Will continue anyway in 2 seconds, press `Interrupt` to stop.)*"
|
||||
print("Warning: It is highly recommended you use `--load-in-8bit` for LoRA training.")
|
||||
time.sleep(2) # Give it a moment for the message to show in UI before continuing
|
||||
time.sleep(2) # Give it a moment for the message to show in UI before continuing
|
||||
|
||||
if cutoff_len <= 0 or micro_batch_size <= 0 or batch_size <= 0 or actual_lr <= 0 or lora_rank <= 0 or lora_alpha <= 0:
|
||||
yield "Cannot input zeroes."
|
||||
@@ -145,10 +152,10 @@ def do_train(lora_name: str, micro_batch_size: int, batch_size: int, epochs: int
|
||||
# == Prep the dataset, format, etc ==
|
||||
if raw_text_file not in ['None', '']:
|
||||
print("Loading raw text file dataset...")
|
||||
with open(clean_path('training/datasets', f'{raw_text_file}.txt'), 'r') as file:
|
||||
with open(clean_path('training/datasets', f'{raw_text_file}.txt'), 'r', encoding='utf-8') as file:
|
||||
raw_text = file.read()
|
||||
tokens = shared.tokenizer.encode(raw_text)
|
||||
del raw_text # Note: could be a gig for a large dataset, so delete redundant data as we go to be safe on RAM
|
||||
del raw_text # Note: could be a gig for a large dataset, so delete redundant data as we go to be safe on RAM
|
||||
|
||||
tokens = list(split_chunks(tokens, cutoff_len - overlap_len))
|
||||
for i in range(1, len(tokens)):
|
||||
@@ -197,18 +204,18 @@ def do_train(lora_name: str, micro_batch_size: int, batch_size: int, epochs: int
|
||||
else:
|
||||
eval_data = load_dataset("json", data_files=clean_path('training/datasets', f'{eval_dataset}.json'))
|
||||
eval_data = eval_data['train'].shuffle().map(generate_and_tokenize_prompt)
|
||||
|
||||
|
||||
# == Start prepping the model itself ==
|
||||
if not hasattr(shared.model, 'lm_head') or hasattr(shared.model.lm_head, 'weight'):
|
||||
print("Getting model ready...")
|
||||
prepare_model_for_int8_training(shared.model)
|
||||
|
||||
|
||||
print("Prepping for training...")
|
||||
config = LoraConfig(
|
||||
r=lora_rank,
|
||||
lora_alpha=lora_alpha,
|
||||
# TODO: Should target_modules be configurable?
|
||||
target_modules=[ "q_proj", "v_proj" ],
|
||||
target_modules=["q_proj", "v_proj"],
|
||||
lora_dropout=lora_dropout,
|
||||
bias="none",
|
||||
task_type="CAUSAL_LM"
|
||||
@@ -289,7 +296,7 @@ def do_train(lora_name: str, micro_batch_size: int, batch_size: int, epochs: int
|
||||
timer_info = f"`{its:.2f}` it/s"
|
||||
else:
|
||||
timer_info = f"`{1.0/its:.2f}` s/it"
|
||||
total_time_estimate = (1.0/its) * (MAX_STEPS)
|
||||
total_time_estimate = (1.0 / its) * (MAX_STEPS)
|
||||
yield f"Running... **{CURRENT_STEPS}** / **{MAX_STEPS}** ... {timer_info}, {format_time(time_elapsed)} / {format_time(total_time_estimate)} ... {format_time(total_time_estimate - time_elapsed)} remaining"
|
||||
|
||||
print("Training complete, saving...")
|
||||
@@ -302,10 +309,12 @@ def do_train(lora_name: str, micro_batch_size: int, batch_size: int, epochs: int
|
||||
print("Training complete!")
|
||||
yield f"Done! LoRA saved to `{lora_name}`"
|
||||
|
||||
|
||||
def split_chunks(arr, step):
|
||||
for i in range(0, len(arr), step):
|
||||
yield arr[i:i + step]
|
||||
|
||||
|
||||
def cut_chunk_for_newline(chunk: str, max_length: int):
|
||||
if '\n' not in chunk:
|
||||
return chunk
|
||||
@@ -319,6 +328,7 @@ def cut_chunk_for_newline(chunk: str, max_length: int):
|
||||
chunk = chunk[:last_newline]
|
||||
return chunk
|
||||
|
||||
|
||||
def format_time(seconds: float):
|
||||
if seconds < 120:
|
||||
return f"`{seconds:.0f}` seconds"
|
||||
|
||||
@@ -13,6 +13,7 @@ with open(Path(__file__).resolve().parent / '../css/main.js', 'r') as f:
|
||||
with open(Path(__file__).resolve().parent / '../css/chat.js', 'r') as f:
|
||||
chat_js = f.read()
|
||||
|
||||
|
||||
class ToolButton(gr.Button, gr.components.FormComponent):
|
||||
"""Small button with single emoji as text, fits inside gradio forms"""
|
||||
|
||||
@@ -22,6 +23,7 @@ class ToolButton(gr.Button, gr.components.FormComponent):
|
||||
def get_block_name(self):
|
||||
return "button"
|
||||
|
||||
|
||||
def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id):
|
||||
def refresh():
|
||||
refresh_method()
|
||||
|
||||
159
server.py
159
server.py
@@ -15,12 +15,11 @@ import gradio as gr
|
||||
from PIL import Image
|
||||
|
||||
import modules.extensions as extensions_module
|
||||
from modules import chat, shared, training, ui, api
|
||||
from modules import api, chat, shared, training, ui
|
||||
from modules.html_generator import chat_html_wrapper
|
||||
from modules.LoRA import add_lora_to_model
|
||||
from modules.models import load_model, load_soft_prompt
|
||||
from modules.text_generation import (clear_torch_cache, generate_reply,
|
||||
stop_everything_event)
|
||||
from modules.models import load_model, load_soft_prompt, unload_model
|
||||
from modules.text_generation import generate_reply, stop_everything_event
|
||||
|
||||
# Loading custom settings
|
||||
settings_file = None
|
||||
@@ -34,15 +33,18 @@ if settings_file is not None:
|
||||
for item in new_settings:
|
||||
shared.settings[item] = new_settings[item]
|
||||
|
||||
|
||||
def get_available_models():
|
||||
if shared.args.flexgen:
|
||||
return sorted([re.sub('-np$', '', item.name) for item in list(Path(f'{shared.args.model_dir}/').glob('*')) if item.name.endswith('-np')], key=str.lower)
|
||||
else:
|
||||
return sorted([re.sub('.pth$', '', item.name) for item in list(Path(f'{shared.args.model_dir}/').glob('*')) if not item.name.endswith(('.txt', '-np', '.pt', '.json'))], key=str.lower)
|
||||
|
||||
|
||||
def get_available_presets():
|
||||
return sorted(set((k.stem for k in Path('presets').glob('*.txt'))), key=str.lower)
|
||||
|
||||
|
||||
def get_available_prompts():
|
||||
prompts = []
|
||||
prompts += sorted(set((k.stem for k in Path('prompts').glob('[0-9]*.txt'))), key=str.lower, reverse=True)
|
||||
@@ -50,10 +52,12 @@ def get_available_prompts():
|
||||
prompts += ['None']
|
||||
return prompts
|
||||
|
||||
|
||||
def get_available_characters():
|
||||
paths = (x for x in Path('characters').iterdir() if x.suffix in ('.json', '.yaml', '.yml'))
|
||||
return ['None'] + sorted(set((k.stem for k in paths if k.stem != "instruction-following")), key=str.lower)
|
||||
|
||||
|
||||
def get_available_instruction_templates():
|
||||
path = "characters/instruction-following"
|
||||
paths = []
|
||||
@@ -61,18 +65,18 @@ def get_available_instruction_templates():
|
||||
paths = (x for x in Path(path).iterdir() if x.suffix in ('.json', '.yaml', '.yml'))
|
||||
return ['None'] + sorted(set((k.stem for k in paths)), key=str.lower)
|
||||
|
||||
|
||||
def get_available_extensions():
|
||||
return sorted(set(map(lambda x : x.parts[1], Path('extensions').glob('*/script.py'))), key=str.lower)
|
||||
return sorted(set(map(lambda x: x.parts[1], Path('extensions').glob('*/script.py'))), key=str.lower)
|
||||
|
||||
|
||||
def get_available_softprompts():
|
||||
return ['None'] + sorted(set((k.stem for k in Path('softprompts').glob('*.zip'))), key=str.lower)
|
||||
|
||||
|
||||
def get_available_loras():
|
||||
return ['None'] + sorted([item.name for item in list(Path(shared.args.lora_dir).glob('*')) if not item.name.endswith(('.txt', '-np', '.pt', '.json'))], key=str.lower)
|
||||
|
||||
def unload_model():
|
||||
shared.model = shared.tokenizer = None
|
||||
clear_torch_cache()
|
||||
|
||||
def load_model_wrapper(selected_model):
|
||||
if selected_model != shared.model_name:
|
||||
@@ -84,10 +88,12 @@ def load_model_wrapper(selected_model):
|
||||
|
||||
return selected_model
|
||||
|
||||
|
||||
def load_lora_wrapper(selected_lora):
|
||||
add_lora_to_model(selected_lora)
|
||||
return selected_lora
|
||||
|
||||
|
||||
def load_preset_values(preset_menu, state, return_dict=False):
|
||||
generate_params = {
|
||||
'do_sample': True,
|
||||
@@ -118,6 +124,7 @@ def load_preset_values(preset_menu, state, return_dict=False):
|
||||
state.update(generate_params)
|
||||
return state, *[generate_params[k] for k in ['do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']]
|
||||
|
||||
|
||||
def upload_soft_prompt(file):
|
||||
with zipfile.ZipFile(io.BytesIO(file)) as zf:
|
||||
zf.extract('meta.json')
|
||||
@@ -130,12 +137,14 @@ def upload_soft_prompt(file):
|
||||
|
||||
return name
|
||||
|
||||
|
||||
def save_prompt(text):
|
||||
fname = f"{datetime.now().strftime('%Y-%m-%d-%H%M%S')}.txt"
|
||||
with open(Path(f'prompts/{fname}'), 'w', encoding='utf-8') as f:
|
||||
f.write(text)
|
||||
return f"Saved to prompts/{fname}"
|
||||
|
||||
|
||||
def load_prompt(fname):
|
||||
if fname in ['None', '']:
|
||||
return ''
|
||||
@@ -146,12 +155,13 @@ def load_prompt(fname):
|
||||
text = text[:-1]
|
||||
return text
|
||||
|
||||
|
||||
def create_prompt_menus():
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
with gr.Row():
|
||||
shared.gradio['prompt_menu'] = gr.Dropdown(choices=get_available_prompts(), value='None', label='Prompt')
|
||||
ui.create_refresh_button(shared.gradio['prompt_menu'], lambda : None, lambda : {'choices': get_available_prompts()}, 'refresh-button')
|
||||
ui.create_refresh_button(shared.gradio['prompt_menu'], lambda: None, lambda: {'choices': get_available_prompts()}, 'refresh-button')
|
||||
|
||||
with gr.Column():
|
||||
with gr.Column():
|
||||
@@ -161,20 +171,22 @@ def create_prompt_menus():
|
||||
shared.gradio['prompt_menu'].change(load_prompt, [shared.gradio['prompt_menu']], [shared.gradio['textbox']], show_progress=False)
|
||||
shared.gradio['save_prompt'].click(save_prompt, [shared.gradio['textbox']], [shared.gradio['status']], show_progress=False)
|
||||
|
||||
|
||||
def create_model_menus():
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
with gr.Row():
|
||||
shared.gradio['model_menu'] = gr.Dropdown(choices=available_models, value=shared.model_name, label='Model')
|
||||
ui.create_refresh_button(shared.gradio['model_menu'], lambda : None, lambda : {'choices': get_available_models()}, 'refresh-button')
|
||||
ui.create_refresh_button(shared.gradio['model_menu'], lambda: None, lambda: {'choices': get_available_models()}, 'refresh-button')
|
||||
with gr.Column():
|
||||
with gr.Row():
|
||||
shared.gradio['lora_menu'] = gr.Dropdown(choices=available_loras, value=shared.lora_name, label='LoRA')
|
||||
ui.create_refresh_button(shared.gradio['lora_menu'], lambda : None, lambda : {'choices': get_available_loras()}, 'refresh-button')
|
||||
ui.create_refresh_button(shared.gradio['lora_menu'], lambda: None, lambda: {'choices': get_available_loras()}, 'refresh-button')
|
||||
|
||||
shared.gradio['model_menu'].change(load_model_wrapper, shared.gradio['model_menu'], shared.gradio['model_menu'], show_progress=True)
|
||||
shared.gradio['lora_menu'].change(load_lora_wrapper, shared.gradio['lora_menu'], shared.gradio['lora_menu'], show_progress=True)
|
||||
|
||||
|
||||
def create_settings_menus(default_preset):
|
||||
generate_params = load_preset_values(default_preset if not shared.args.flexgen else 'Naive', {}, return_dict=True)
|
||||
for k in ['max_new_tokens', 'seed', 'stop_at_newline', 'chat_prompt_size', 'chat_generation_attempts']:
|
||||
@@ -185,7 +197,7 @@ def create_settings_menus(default_preset):
|
||||
with gr.Column():
|
||||
with gr.Row():
|
||||
shared.gradio['preset_menu'] = gr.Dropdown(choices=available_presets, value=default_preset if not shared.args.flexgen else 'Naive', label='Generation parameters preset')
|
||||
ui.create_refresh_button(shared.gradio['preset_menu'], lambda : None, lambda : {'choices': get_available_presets()}, 'refresh-button')
|
||||
ui.create_refresh_button(shared.gradio['preset_menu'], lambda: None, lambda: {'choices': get_available_presets()}, 'refresh-button')
|
||||
with gr.Column():
|
||||
shared.gradio['seed'] = gr.Number(value=shared.settings['seed'], label='Seed (-1 for random)')
|
||||
|
||||
@@ -196,12 +208,12 @@ def create_settings_menus(default_preset):
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
shared.gradio['temperature'] = gr.Slider(0.01, 1.99, value=generate_params['temperature'], step=0.01, label='temperature')
|
||||
shared.gradio['top_p'] = gr.Slider(0.0,1.0,value=generate_params['top_p'],step=0.01,label='top_p')
|
||||
shared.gradio['top_k'] = gr.Slider(0,200,value=generate_params['top_k'],step=1,label='top_k')
|
||||
shared.gradio['typical_p'] = gr.Slider(0.0,1.0,value=generate_params['typical_p'],step=0.01,label='typical_p')
|
||||
shared.gradio['top_p'] = gr.Slider(0.0, 1.0, value=generate_params['top_p'], step=0.01, label='top_p')
|
||||
shared.gradio['top_k'] = gr.Slider(0, 200, value=generate_params['top_k'], step=1, label='top_k')
|
||||
shared.gradio['typical_p'] = gr.Slider(0.0, 1.0, value=generate_params['typical_p'], step=0.01, label='typical_p')
|
||||
with gr.Column():
|
||||
shared.gradio['repetition_penalty'] = gr.Slider(1.0, 1.5, value=generate_params['repetition_penalty'],step=0.01,label='repetition_penalty')
|
||||
shared.gradio['encoder_repetition_penalty'] = gr.Slider(0.8, 1.5, value=generate_params['encoder_repetition_penalty'],step=0.01,label='encoder_repetition_penalty')
|
||||
shared.gradio['repetition_penalty'] = gr.Slider(1.0, 1.5, value=generate_params['repetition_penalty'], step=0.01, label='repetition_penalty')
|
||||
shared.gradio['encoder_repetition_penalty'] = gr.Slider(0.8, 1.5, value=generate_params['encoder_repetition_penalty'], step=0.01, label='encoder_repetition_penalty')
|
||||
shared.gradio['no_repeat_ngram_size'] = gr.Slider(0, 20, step=1, value=generate_params['no_repeat_ngram_size'], label='no_repeat_ngram_size')
|
||||
shared.gradio['min_length'] = gr.Slider(0, 2000, step=1, value=generate_params['min_length'] if shared.args.no_stream else 0, label='min_length', interactive=shared.args.no_stream)
|
||||
shared.gradio['do_sample'] = gr.Checkbox(value=generate_params['do_sample'], label='do_sample')
|
||||
@@ -209,7 +221,6 @@ def create_settings_menus(default_preset):
|
||||
with gr.Box():
|
||||
gr.Markdown('Contrastive search')
|
||||
shared.gradio['penalty_alpha'] = gr.Slider(0, 5, value=generate_params['penalty_alpha'], label='penalty_alpha')
|
||||
|
||||
with gr.Box():
|
||||
gr.Markdown('Beam search (uses a lot of VRAM)')
|
||||
with gr.Row():
|
||||
@@ -219,11 +230,10 @@ def create_settings_menus(default_preset):
|
||||
shared.gradio['length_penalty'] = gr.Slider(-5, 5, value=generate_params['length_penalty'], label='length_penalty')
|
||||
shared.gradio['early_stopping'] = gr.Checkbox(value=generate_params['early_stopping'], label='early_stopping')
|
||||
|
||||
|
||||
with gr.Accordion('Soft prompt', open=False):
|
||||
with gr.Row():
|
||||
shared.gradio['softprompts_menu'] = gr.Dropdown(choices=available_softprompts, value='None', label='Soft prompt')
|
||||
ui.create_refresh_button(shared.gradio['softprompts_menu'], lambda : None, lambda : {'choices': get_available_softprompts()}, 'refresh-button')
|
||||
ui.create_refresh_button(shared.gradio['softprompts_menu'], lambda: None, lambda: {'choices': get_available_softprompts()}, 'refresh-button')
|
||||
|
||||
gr.Markdown('Upload a soft prompt (.zip format):')
|
||||
with gr.Row():
|
||||
@@ -233,6 +243,7 @@ def create_settings_menus(default_preset):
|
||||
shared.gradio['softprompts_menu'].change(load_soft_prompt, shared.gradio['softprompts_menu'], shared.gradio['softprompts_menu'], show_progress=True)
|
||||
shared.gradio['upload_softprompt'].upload(upload_soft_prompt, shared.gradio['upload_softprompt'], shared.gradio['softprompts_menu'])
|
||||
|
||||
|
||||
def set_interface_arguments(interface_mode, extensions, bool_active):
|
||||
modes = ["default", "notebook", "chat", "cai_chat"]
|
||||
cmd_list = vars(shared.args)
|
||||
@@ -251,6 +262,7 @@ def set_interface_arguments(interface_mode, extensions, bool_active):
|
||||
|
||||
shared.need_restart = True
|
||||
|
||||
|
||||
available_models = get_available_models()
|
||||
available_presets = get_available_presets()
|
||||
available_characters = get_available_characters()
|
||||
@@ -284,7 +296,7 @@ else:
|
||||
for i, model in enumerate(available_models):
|
||||
print(f'{i+1}. {model}')
|
||||
print(f'\nWhich one do you want to load? 1-{len(available_models)}\n')
|
||||
i = int(input())-1
|
||||
i = int(input()) - 1
|
||||
print()
|
||||
shared.model_name = available_models[i]
|
||||
shared.model, shared.tokenizer = load_model(shared.model_name)
|
||||
@@ -297,22 +309,22 @@ if shared.lora_name != "None":
|
||||
default_text = load_prompt(shared.settings['lora_prompts'][next((k for k in shared.settings['lora_prompts'] if re.match(k.lower(), shared.lora_name.lower())), 'default')])
|
||||
else:
|
||||
default_text = load_prompt(shared.settings['prompts'][next((k for k in shared.settings['prompts'] if re.match(k.lower(), shared.model_name.lower())), 'default')])
|
||||
title ='Text generation web UI'
|
||||
title = 'Text generation web UI'
|
||||
|
||||
|
||||
def create_interface():
|
||||
|
||||
gen_events = []
|
||||
if shared.args.extensions is not None and len(shared.args.extensions) > 0:
|
||||
extensions_module.load_extensions()
|
||||
|
||||
with gr.Blocks(css=ui.css if not shared.is_chat() else ui.css+ui.chat_css, analytics_enabled=False, title=title) as shared.gradio['interface']:
|
||||
with gr.Blocks(css=ui.css if not shared.is_chat() else ui.css + ui.chat_css, analytics_enabled=False, title=title) as shared.gradio['interface']:
|
||||
if shared.is_chat():
|
||||
shared.gradio['Chat input'] = gr.State()
|
||||
with gr.Tab("Text generation", elem_id="main"):
|
||||
shared.gradio['display'] = gr.HTML(value=chat_html_wrapper(shared.history['visible'], shared.settings['name1'], shared.settings['name2'], 'cai-chat'))
|
||||
shared.gradio['textbox'] = gr.Textbox(label='Input')
|
||||
with gr.Row():
|
||||
shared.gradio['Generate'] = gr.Button('Generate')
|
||||
shared.gradio['Generate'] = gr.Button('Generate', elem_id='Generate')
|
||||
shared.gradio['Stop'] = gr.Button('Stop', elem_id="stop")
|
||||
with gr.Row():
|
||||
shared.gradio['Impersonate'] = gr.Button('Impersonate')
|
||||
@@ -342,7 +354,7 @@ def create_interface():
|
||||
shared.gradio['your_picture'] = gr.Image(label='Your picture', type="pil", value=Image.open(Path("cache/pfp_me.png")) if Path("cache/pfp_me.png").exists() else None)
|
||||
with gr.Row():
|
||||
shared.gradio['character_menu'] = gr.Dropdown(choices=available_characters, value='None', label='Character', elem_id='character-menu')
|
||||
ui.create_refresh_button(shared.gradio['character_menu'], lambda : None, lambda : {'choices': get_available_characters()}, 'refresh-button')
|
||||
ui.create_refresh_button(shared.gradio['character_menu'], lambda: None, lambda: {'choices': get_available_characters()}, 'refresh-button')
|
||||
|
||||
with gr.Row():
|
||||
with gr.Tab('Chat history'):
|
||||
@@ -382,55 +394,71 @@ def create_interface():
|
||||
create_settings_menus(default_preset)
|
||||
|
||||
shared.input_params = [shared.gradio[k] for k in ['Chat input', 'generate_state', 'name1', 'name2', 'context', 'Chat mode', 'end_of_turn']]
|
||||
clear_arr = [shared.gradio[k] for k in ['Clear history-confirm', 'Clear history', 'Clear history-cancel']]
|
||||
reload_inputs = [shared.gradio[k] for k in ['name1', 'name2', 'Chat mode']]
|
||||
|
||||
def set_chat_input(textbox):
|
||||
return textbox, ""
|
||||
|
||||
gen_events.append(shared.gradio['Generate'].click(set_chat_input, shared.gradio['textbox'], [shared.gradio['Chat input'], shared.gradio['textbox']], show_progress=False))
|
||||
gen_events.append(shared.gradio['Generate'].click(chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream))
|
||||
gen_events.append(shared.gradio['textbox'].submit(set_chat_input, shared.gradio['textbox'], [shared.gradio['Chat input'], shared.gradio['textbox']], show_progress=False))
|
||||
gen_events.append(shared.gradio['textbox'].submit(chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream))
|
||||
gen_events.append(shared.gradio['Regenerate'].click(chat.regenerate_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream))
|
||||
gen_events.append(shared.gradio['Generate'].click(
|
||||
set_chat_input, shared.gradio['textbox'], [shared.gradio['Chat input'], shared.gradio['textbox']], show_progress=False).then(
|
||||
chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream).then(
|
||||
lambda x: '', shared.gradio['textbox'], shared.gradio['textbox'], show_progress=False).then(
|
||||
lambda: chat.save_history(timestamp=False), [], [], show_progress=False)
|
||||
)
|
||||
|
||||
gen_events.append(shared.gradio['textbox'].submit(
|
||||
set_chat_input, shared.gradio['textbox'], [shared.gradio['Chat input'], shared.gradio['textbox']], show_progress=False).then(
|
||||
chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream).then(
|
||||
lambda x: '', shared.gradio['textbox'], shared.gradio['textbox'], show_progress=False).then(
|
||||
lambda: chat.save_history(timestamp=False), [], [], show_progress=False)
|
||||
)
|
||||
|
||||
gen_events.append(shared.gradio['Regenerate'].click(
|
||||
chat.regenerate_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream).then(
|
||||
lambda x: '', shared.gradio['textbox'], shared.gradio['textbox'], show_progress=False).then(
|
||||
lambda: chat.save_history(timestamp=False), [], [], show_progress=False)
|
||||
)
|
||||
|
||||
shared.gradio['Replace last reply'].click(
|
||||
chat.replace_last_reply, [shared.gradio[k] for k in ['textbox', 'name1', 'name2', 'Chat mode']], shared.gradio['display'], show_progress=shared.args.no_stream).then(
|
||||
lambda x: '', shared.gradio['textbox'], shared.gradio['textbox'], show_progress=False).then(
|
||||
lambda: chat.save_history(timestamp=False), [], [], show_progress=False)
|
||||
|
||||
shared.gradio['Clear history-confirm'].click(
|
||||
lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, clear_arr).then(
|
||||
chat.clear_chat_log, [shared.gradio[k] for k in ['name1', 'name2', 'greeting', 'Chat mode']], shared.gradio['display']).then(
|
||||
lambda: chat.save_history(timestamp=False), [], [], show_progress=False)
|
||||
|
||||
shared.gradio['Stop'].click(
|
||||
stop_everything_event, [], [], queue=False, cancels=gen_events if shared.args.no_stream else None).then(
|
||||
chat.redraw_html, reload_inputs, [shared.gradio['display']])
|
||||
|
||||
shared.gradio['Chat mode'].change(
|
||||
lambda x: gr.update(visible=x == 'instruct'), shared.gradio['Chat mode'], shared.gradio['Instruction templates']).then(
|
||||
chat.redraw_html, reload_inputs, [shared.gradio['display']])
|
||||
|
||||
shared.gradio['Instruction templates'].change(
|
||||
lambda character, name1, name2, mode: chat.load_character(character, name1, name2, mode), [shared.gradio[k] for k in ['Instruction templates', 'name1', 'name2', 'Chat mode']], [shared.gradio[k] for k in ['name1', 'name2', 'character_picture', 'greeting', 'context', 'end_of_turn', 'display']]).then(
|
||||
chat.redraw_html, reload_inputs, [shared.gradio['display']])
|
||||
|
||||
shared.gradio['upload_chat_history'].upload(
|
||||
chat.load_history, [shared.gradio[k] for k in ['upload_chat_history', 'name1', 'name2']], []).then(
|
||||
chat.redraw_html, reload_inputs, [shared.gradio['display']])
|
||||
|
||||
gen_events.append(shared.gradio['Impersonate'].click(chat.impersonate_wrapper, shared.input_params, shared.gradio['textbox'], show_progress=shared.args.no_stream))
|
||||
shared.gradio['Stop'].click(stop_everything_event, [], [], queue=False, cancels=gen_events if shared.args.no_stream else None)
|
||||
|
||||
shared.gradio['Copy last reply'].click(chat.send_last_reply_to_input, [], shared.gradio['textbox'], show_progress=shared.args.no_stream)
|
||||
shared.gradio['Replace last reply'].click(chat.replace_last_reply, [shared.gradio[k] for k in ['textbox', 'name1', 'name2', 'Chat mode']], shared.gradio['display'], show_progress=shared.args.no_stream)
|
||||
|
||||
# Clear history with confirmation
|
||||
clear_arr = [shared.gradio[k] for k in ['Clear history-confirm', 'Clear history', 'Clear history-cancel']]
|
||||
shared.gradio['Clear history'].click(lambda :[gr.update(visible=True), gr.update(visible=False), gr.update(visible=True)], None, clear_arr)
|
||||
shared.gradio['Clear history-confirm'].click(lambda :[gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, clear_arr)
|
||||
shared.gradio['Clear history-confirm'].click(chat.clear_chat_log, [shared.gradio[k] for k in ['name1', 'name2', 'greeting', 'Chat mode']], shared.gradio['display'])
|
||||
shared.gradio['Clear history-cancel'].click(lambda :[gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, clear_arr)
|
||||
shared.gradio['Chat mode'].change(lambda x : gr.update(visible= x=='instruct'), shared.gradio['Chat mode'], shared.gradio['Instruction templates'])
|
||||
|
||||
shared.gradio['Clear history'].click(lambda: [gr.update(visible=True), gr.update(visible=False), gr.update(visible=True)], None, clear_arr)
|
||||
shared.gradio['Clear history-cancel'].click(lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, clear_arr)
|
||||
shared.gradio['Remove last'].click(chat.remove_last_message, [shared.gradio[k] for k in ['name1', 'name2', 'Chat mode']], [shared.gradio['display'], shared.gradio['textbox']], show_progress=False)
|
||||
shared.gradio['download_button'].click(chat.save_history, inputs=[], outputs=[shared.gradio['download']])
|
||||
shared.gradio['Upload character'].click(chat.upload_character, [shared.gradio['upload_json'], shared.gradio['upload_img_bot']], [shared.gradio['character_menu']])
|
||||
|
||||
# Clearing stuff and saving the history
|
||||
for i in ['Generate', 'Regenerate', 'Replace last reply']:
|
||||
shared.gradio[i].click(lambda x: '', shared.gradio['textbox'], shared.gradio['textbox'], show_progress=False)
|
||||
shared.gradio[i].click(lambda : chat.save_history(timestamp=False), [], [], show_progress=False)
|
||||
shared.gradio['Clear history-confirm'].click(lambda : chat.save_history(timestamp=False), [], [], show_progress=False)
|
||||
shared.gradio['textbox'].submit(lambda x: '', shared.gradio['textbox'], shared.gradio['textbox'], show_progress=False)
|
||||
shared.gradio['textbox'].submit(lambda : chat.save_history(timestamp=False), [], [], show_progress=False)
|
||||
|
||||
shared.gradio['character_menu'].change(chat.load_character, [shared.gradio[k] for k in ['character_menu', 'name1', 'name2', 'Chat mode']], [shared.gradio[k] for k in ['name1', 'name2', 'character_picture', 'greeting', 'context', 'end_of_turn', 'display']])
|
||||
shared.gradio['Instruction templates'].change(lambda character, name1, name2, mode: chat.load_character(character, name1, name2, mode), [shared.gradio[k] for k in ['Instruction templates', 'name1', 'name2', 'Chat mode']], [shared.gradio[k] for k in ['name1', 'name2', 'character_picture', 'greeting', 'context', 'end_of_turn', 'display']])
|
||||
shared.gradio['upload_chat_history'].upload(chat.load_history, [shared.gradio[k] for k in ['upload_chat_history', 'name1', 'name2']], [])
|
||||
shared.gradio['upload_img_tavern'].upload(chat.upload_tavern_character, [shared.gradio['upload_img_tavern'], shared.gradio['name1'], shared.gradio['name2']], [shared.gradio['character_menu']])
|
||||
shared.gradio['your_picture'].change(chat.upload_your_profile_picture, [shared.gradio[k] for k in ['your_picture', 'name1', 'name2', 'Chat mode']], shared.gradio['display'])
|
||||
|
||||
reload_inputs = [shared.gradio[k] for k in ['name1', 'name2', 'Chat mode']]
|
||||
shared.gradio['upload_chat_history'].upload(chat.redraw_html, reload_inputs, [shared.gradio['display']])
|
||||
shared.gradio['Stop'].click(chat.redraw_html, reload_inputs, [shared.gradio['display']])
|
||||
shared.gradio['Instruction templates'].change(chat.redraw_html, reload_inputs, [shared.gradio['display']])
|
||||
shared.gradio['Chat mode'].change(chat.redraw_html, reload_inputs, [shared.gradio['display']])
|
||||
|
||||
shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js+ui.chat_js}}}")
|
||||
shared.gradio['interface'].load(lambda : chat.load_default_history(shared.settings['name1'], shared.settings['name2']), None, None)
|
||||
shared.gradio['interface'].load(chat.load_default_history, [shared.gradio[k] for k in ['name1', 'name2']], None)
|
||||
shared.gradio['interface'].load(chat.redraw_html, reload_inputs, [shared.gradio['display']], show_progress=True)
|
||||
|
||||
elif shared.args.notebook:
|
||||
@@ -523,10 +551,12 @@ def create_interface():
|
||||
shared.gradio['interface_modes_menu'] = gr.Dropdown(choices=modes, value=current_mode, label="Mode")
|
||||
shared.gradio['extensions_menu'] = gr.CheckboxGroup(choices=get_available_extensions(), value=shared.args.extensions, label="Available extensions")
|
||||
shared.gradio['bool_menu'] = gr.CheckboxGroup(choices=bool_list, value=bool_active, label="Boolean command-line flags")
|
||||
shared.gradio['reset_interface'] = gr.Button("Apply and restart the interface", type="primary")
|
||||
shared.gradio['reset_interface'] = gr.Button("Apply and restart the interface")
|
||||
|
||||
shared.gradio['reset_interface'].click(set_interface_arguments, [shared.gradio[k] for k in ['interface_modes_menu', 'extensions_menu', 'bool_menu']], None)
|
||||
shared.gradio['reset_interface'].click(lambda : None, None, None, _js='() => {document.body.innerHTML=\'<h1 style="font-family:monospace;margin-top:20%;color:lightgray;text-align:center;">Reloading...</h1>\'; setTimeout(function(){location.reload()},2500); return []}')
|
||||
# Reset interface event
|
||||
shared.gradio['reset_interface'].click(
|
||||
set_interface_arguments, [shared.gradio[k] for k in ['interface_modes_menu', 'extensions_menu', 'bool_menu']], None).then(
|
||||
lambda: None, None, None, _js='() => {document.body.innerHTML=\'<h1 style="font-family:monospace;margin-top:20%;color:lightgray;text-align:center;">Reloading...</h1>\'; setTimeout(function(){location.reload()},2500); return []}')
|
||||
|
||||
if shared.args.extensions is not None:
|
||||
extensions_module.create_extensions_block()
|
||||
@@ -562,6 +592,7 @@ def create_interface():
|
||||
else:
|
||||
shared.gradio['interface'].launch(prevent_thread_lock=True, share=shared.args.share, server_port=shared.args.listen_port, inbrowser=shared.args.auto_launch, auth=auth)
|
||||
|
||||
|
||||
create_interface()
|
||||
|
||||
while True:
|
||||
|
||||
Reference in New Issue
Block a user