17 Commits

Author SHA1 Message Date
oobabooga
7b1b1d8e39 Make the code more like PEP8 2023-04-07 00:00:16 -03:00
oobabooga
8e4e8dd741 Unrelated bugs that I noticed and had to fix 2023-04-06 23:33:09 -03:00
oobabooga
7ac7f1bc9a Merge branch 'main' into da3dsoul-main 2023-04-06 23:22:45 -03:00
da3dsoul
7795e087a7 Fix P, V, and E sounding odd. Add Slash to the punctuation list
Also add torch and torchaudio back to the requirements, as silero needs them. Silero's requirements.txt should be everything needed to run the tests
2023-04-06 21:48:28 -04:00
oobabooga
773e1246da Add back the spaces 2023-04-06 20:45:41 -03:00
oobabooga
7e31bc485c Merge branch 'main' into da3dsoul-main 2023-04-06 20:33:49 -03:00
oobabooga
de7fb66d52 Remove torch from requirements 2023-04-06 20:31:04 -03:00
oobabooga
e80d0b0160 Reorder imports 2023-04-06 20:21:13 -03:00
da3dsoul
38258c4471 Remove unused locale import 2023-04-04 16:34:45 -04:00
da3dsoul
1848938f7f Improve Silero Preprocessing to Handle European Numbers and Decimals
Also add a test script to generate audio clips from CLI
2023-04-04 16:24:25 -04:00
da3dsoul
d5f3036687 Improve Silero's Preprocessor to Handle Abbreviations and Initials Better 2023-04-04 14:15:58 -04:00
da3dsoul
695031e902 Fix a typo in the Silero Preprocessor 2023-04-04 00:20:04 -04:00
da3dsoul
90d7915b86 Merge branch 'oobabooga:main' into main 2023-04-04 00:17:03 -04:00
da3dsoul
c05f727ae4 Improve Silero's Preprocessor to Handle Negative Numbers Better 2023-04-04 00:09:50 -04:00
da3dsoul
39063c48bb Improve Silero's Preprocessor to Handle Punctuation and Whitespace Better 2023-04-03 20:24:09 -04:00
da3dsoul
7d4e419dbe Improve Silero's Preprocessor to Handle Roman Numerals and Abbreviations Better 2023-04-03 18:19:25 -04:00
da3dsoul
b2022d0869 Improve Silero's Preprocessor to Handle Numbers and Abbreviations Better 2023-04-03 18:19:24 -04:00
4 changed files with 291 additions and 28 deletions

View File

@@ -1,4 +1,5 @@
ipython
num2words
omegaconf
pydub
PyYAML

View File

@@ -1,14 +1,16 @@
import re
import time
from pathlib import Path
import gradio as gr
import modules.chat as chat
import modules.shared as shared
import torch
from extensions.silero_tts import tts_preprocessor
from modules import chat, shared
from modules.html_generator import chat_html_wrapper
torch._C._jit_set_profiling_mode(False)
params = {
'activate': True,
'speaker': 'en_56',
@@ -26,7 +28,7 @@ current_params = params.copy()
voices_by_gender = ['en_99', 'en_45', 'en_18', 'en_117', 'en_49', 'en_51', 'en_68', 'en_0', 'en_26', 'en_56', 'en_74', 'en_5', 'en_38', 'en_53', 'en_21', 'en_37', 'en_107', 'en_10', 'en_82', 'en_16', 'en_41', 'en_12', 'en_67', 'en_61', 'en_14', 'en_11', 'en_39', 'en_52', 'en_24', 'en_97', 'en_28', 'en_72', 'en_94', 'en_36', 'en_4', 'en_43', 'en_88', 'en_25', 'en_65', 'en_6', 'en_44', 'en_75', 'en_91', 'en_60', 'en_109', 'en_85', 'en_101', 'en_108', 'en_50', 'en_96', 'en_64', 'en_92', 'en_76', 'en_33', 'en_116', 'en_48', 'en_98', 'en_86', 'en_62', 'en_54', 'en_95', 'en_55', 'en_111', 'en_3', 'en_83', 'en_8', 'en_47', 'en_59', 'en_1', 'en_2', 'en_7', 'en_9', 'en_13', 'en_15', 'en_17', 'en_19', 'en_20', 'en_22', 'en_23', 'en_27', 'en_29', 'en_30', 'en_31', 'en_32', 'en_34', 'en_35', 'en_40', 'en_42', 'en_46', 'en_57', 'en_58', 'en_63', 'en_66', 'en_69', 'en_70', 'en_71', 'en_73', 'en_77', 'en_78', 'en_79', 'en_80', 'en_81', 'en_84', 'en_87', 'en_89', 'en_90', 'en_93', 'en_100', 'en_102', 'en_103', 'en_104', 'en_105', 'en_106', 'en_110', 'en_112', 'en_113', 'en_114', 'en_115']
voice_pitches = ['x-low', 'low', 'medium', 'high', 'x-high']
voice_speeds = ['x-slow', 'slow', 'medium', 'fast', 'x-fast']
streaming_state = shared.args.no_stream # remember if chat streaming was enabled
streaming_state = shared.args.no_stream # remember if chat streaming was enabled
# Used for making text xml compatible, needed for voice pitch and speed control
table = str.maketrans({
@@ -37,26 +39,25 @@ table = str.maketrans({
'"': """,
})
def xmlesc(txt):
return txt.translate(table)
def load_model():
model, example_text = torch.hub.load(repo_or_dir='snakers4/silero-models', model='silero_tts', language=params['language'], speaker=params['model_id'])
model.to(params['device'])
return model
model = load_model()
def remove_surrounded_chars(string):
# this expression matches to 'as few symbols as possible (0 upwards) between any asterisks' OR
# 'as few symbols as possible (0 upwards) between an asterisk and the end of the string'
return re.sub('\*[^\*]*?(\*|$)','',string)
def remove_tts_from_history(name1, name2):
def remove_tts_from_history(name1, name2, mode):
for i, entry in enumerate(shared.history['internal']):
shared.history['visible'][i] = [shared.history['visible'][i][0], entry[1]]
return chat.generate_chat_output(shared.history['visible'], name1, name2, shared.character)
return chat_html_wrapper(shared.history['visible'], name1, name2, mode)
def toggle_text_in_history(name1, name2):
def toggle_text_in_history(name1, name2, mode):
for i, entry in enumerate(shared.history['visible']):
visible_reply = entry[1]
if visible_reply.startswith('<audio'):
@@ -65,7 +66,8 @@ def toggle_text_in_history(name1, name2):
shared.history['visible'][i] = [shared.history['visible'][i][0], f"{visible_reply.split('</audio>')[0]}</audio>\n\n{reply}"]
else:
shared.history['visible'][i] = [shared.history['visible'][i][0], f"{visible_reply.split('</audio>')[0]}</audio>"]
return chat.generate_chat_output(shared.history['visible'], name1, name2, shared.character)
return chat_html_wrapper(shared.history['visible'], name1, name2, mode)
def input_modifier(string):
"""
@@ -75,12 +77,13 @@ def input_modifier(string):
# Remove autoplay from the last reply
if shared.is_chat() and len(shared.history['internal']) > 0:
shared.history['visible'][-1] = [shared.history['visible'][-1][0], shared.history['visible'][-1][1].replace('controls autoplay>','controls>')]
shared.history['visible'][-1] = [shared.history['visible'][-1][0], shared.history['visible'][-1][1].replace('controls autoplay>', 'controls>')]
shared.processing_message = "*Is recording a voice message...*"
shared.args.no_stream = True # Disable streaming cause otherwise the audio output will stutter and begin anew every time the message is being updated
shared.args.no_stream = True # Disable streaming cause otherwise the audio output will stutter and begin anew every time the message is being updated
return string
def output_modifier(string):
"""
This function is applied to the model outputs.
@@ -98,11 +101,7 @@ def output_modifier(string):
return string
original_string = string
string = remove_surrounded_chars(string)
string = string.replace('"', '')
string = string.replace('', '')
string = string.replace('\n', ' ')
string = string.strip()
string = tts_preprocessor.preprocess(string)
if string == '':
string = '*Empty reply, try regenerating*'
@@ -118,9 +117,10 @@ def output_modifier(string):
string += f'\n\n{original_string}'
shared.processing_message = "*Is typing...*"
shared.args.no_stream = streaming_state # restore the streaming option to the previous value
shared.args.no_stream = streaming_state # restore the streaming option to the previous value
return string
def bot_prefix_modifier(string):
"""
This function is only applied in chat mode. It modifies
@@ -130,38 +130,42 @@ def bot_prefix_modifier(string):
return string
def ui():
# Gradio elements
with gr.Accordion("Silero TTS"):
with gr.Row():
activate = gr.Checkbox(value=params['activate'], label='Activate TTS')
autoplay = gr.Checkbox(value=params['autoplay'], label='Play TTS automatically')
show_text = gr.Checkbox(value=params['show_text'], label='Show message text under audio player')
voice = gr.Dropdown(value=params['speaker'], choices=voices_by_gender, label='TTS voice')
with gr.Row():
v_pitch = gr.Dropdown(value=params['voice_pitch'], choices=voice_pitches, label='Voice pitch')
v_speed = gr.Dropdown(value=params['voice_speed'], choices=voice_speeds, label='Voice speed')
with gr.Row():
convert = gr.Button('Permanently replace audios with the message texts')
convert_cancel = gr.Button('Cancel', visible=False)
convert_confirm = gr.Button('Confirm (cannot be undone)', variant="stop", visible=False)
# Convert history with confirmation
convert_arr = [convert_confirm, convert, convert_cancel]
convert.click(lambda :[gr.update(visible=True), gr.update(visible=False), gr.update(visible=True)], None, convert_arr)
convert_confirm.click(lambda :[gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, convert_arr)
convert_confirm.click(remove_tts_from_history, [shared.gradio['name1'], shared.gradio['name2']], shared.gradio['display'])
convert_confirm.click(lambda : chat.save_history(timestamp=False), [], [], show_progress=False)
convert_cancel.click(lambda :[gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, convert_arr)
convert.click(lambda: [gr.update(visible=True), gr.update(visible=False), gr.update(visible=True)], None, convert_arr)
convert_confirm.click(lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, convert_arr)
convert_confirm.click(remove_tts_from_history, [shared.gradio[k] for k in ['name1', 'name2', 'Chat mode']], shared.gradio['display'])
convert_confirm.click(lambda: chat.save_history(timestamp=False), [], [], show_progress=False)
convert_cancel.click(lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, convert_arr)
# Toggle message text in history
show_text.change(lambda x: params.update({"show_text": x}), show_text, None)
show_text.change(toggle_text_in_history, [shared.gradio['name1'], shared.gradio['name2']], shared.gradio['display'])
show_text.change(lambda : chat.save_history(timestamp=False), [], [], show_progress=False)
show_text.change(toggle_text_in_history, [shared.gradio[k] for k in ['name1', 'name2', 'Chat mode']], shared.gradio['display'])
show_text.change(lambda: chat.save_history(timestamp=False), [], [], show_progress=False)
# Event functions to update the parameters in the backend
activate.change(lambda x: params.update({"activate": x}), activate, None)
autoplay.change(lambda x: params.update({"autoplay": x}), autoplay, None)
voice.change(lambda x: params.update({"speaker": x}), voice, None)
v_pitch.change(lambda x: params.update({"voice_pitch": x}), v_pitch, None)
v_speed.change(lambda x: params.update({"voice_speed": x}), v_speed, None)
v_speed.change(lambda x: params.update({"voice_speed": x}), v_speed, None)

View File

@@ -0,0 +1,81 @@
import time
from pathlib import Path
import torch
import tts_preprocessor
torch._C._jit_set_profiling_mode(False)
params = {
'activate': True,
'speaker': 'en_49',
'language': 'en',
'model_id': 'v3_en',
'sample_rate': 48000,
'device': 'cpu',
'show_text': True,
'autoplay': True,
'voice_pitch': 'medium',
'voice_speed': 'medium',
}
current_params = params.copy()
voices_by_gender = ['en_99', 'en_45', 'en_18', 'en_117', 'en_49', 'en_51', 'en_68', 'en_0', 'en_26', 'en_56', 'en_74', 'en_5', 'en_38', 'en_53', 'en_21', 'en_37', 'en_107', 'en_10', 'en_82', 'en_16', 'en_41', 'en_12', 'en_67', 'en_61', 'en_14', 'en_11', 'en_39', 'en_52', 'en_24', 'en_97', 'en_28', 'en_72', 'en_94', 'en_36', 'en_4', 'en_43', 'en_88', 'en_25', 'en_65', 'en_6', 'en_44', 'en_75', 'en_91', 'en_60', 'en_109', 'en_85', 'en_101', 'en_108', 'en_50', 'en_96', 'en_64', 'en_92', 'en_76', 'en_33', 'en_116', 'en_48', 'en_98', 'en_86', 'en_62', 'en_54', 'en_95', 'en_55', 'en_111', 'en_3', 'en_83', 'en_8', 'en_47', 'en_59', 'en_1', 'en_2', 'en_7', 'en_9', 'en_13', 'en_15', 'en_17', 'en_19', 'en_20', 'en_22', 'en_23', 'en_27', 'en_29', 'en_30', 'en_31', 'en_32', 'en_34', 'en_35', 'en_40', 'en_42', 'en_46', 'en_57', 'en_58', 'en_63', 'en_66', 'en_69', 'en_70', 'en_71', 'en_73', 'en_77', 'en_78', 'en_79', 'en_80', 'en_81', 'en_84', 'en_87', 'en_89', 'en_90', 'en_93', 'en_100', 'en_102', 'en_103', 'en_104', 'en_105', 'en_106', 'en_110', 'en_112', 'en_113', 'en_114', 'en_115']
voice_pitches = ['x-low', 'low', 'medium', 'high', 'x-high']
voice_speeds = ['x-slow', 'slow', 'medium', 'fast', 'x-fast']
# Used for making text xml compatible, needed for voice pitch and speed control
table = str.maketrans({
"<": "&lt;",
">": "&gt;",
"&": "&amp;",
"'": "&apos;",
'"': "&quot;",
})
def xmlesc(txt):
return txt.translate(table)
def load_model():
model, example_text = torch.hub.load(repo_or_dir='snakers4/silero-models', model='silero_tts', language=params['language'], speaker=params['model_id'])
model.to(params['device'])
return model
model = load_model()
def output_modifier(string):
"""
This function is applied to the model outputs.
"""
global model, current_params
original_string = string
string = tts_preprocessor.preprocess(string)
processed_string = string
if string == '':
string = '*Empty reply, try regenerating*'
else:
output_file = Path(f'extensions/silero_tts/outputs/test_{int(time.time())}.wav')
prosody = '<prosody rate="{}" pitch="{}">'.format(params['voice_speed'], params['voice_pitch'])
silero_input = f'<speak>{prosody}{xmlesc(string)}</prosody></speak>'
model.save_wav(ssml_text=silero_input, speaker=params['speaker'], sample_rate=int(params['sample_rate']), audio_path=str(output_file))
autoplay = 'autoplay' if params['autoplay'] else ''
string = f'<audio src="file/{output_file.as_posix()}" controls {autoplay}></audio>'
if params['show_text']:
string += f'\n\n{original_string}\n\nProcessed:\n{processed_string}'
print(string)
if __name__ == '__main__':
import sys
output_modifier(sys.argv[1])

View File

@@ -0,0 +1,177 @@
import re
from num2words import num2words
# punctuation = r'[\s,.?!/)"\'\]>]“”'
punctuation = r'[\s,.?!/)"\'\]>]'
alphabet_map = {
"A": " Ei ",
"B": " Bee ",
"C": " See ",
"D": " Dee ",
"E": " Eee ",
"F": " Eff ",
"G": " Jee ",
"H": " Eich ",
"I": " Eye ",
"J": " Jay ",
"K": " Kay ",
"L": " El ",
"M": " Emm ",
"N": " Enn ",
"O": " Ohh ",
"P": " Pee ",
"Q": " Queue ",
"R": " Are ",
"S": " Ess ",
"T": " Tee ",
"U": " You ",
"V": " Vee ",
"W": " Double You ",
"X": " Ex ",
"Y": " Why ",
"Z": " Zed " # Zed is weird, as I (da3dsoul) am American, but most of the voice models sound British, so it matches
}
def preprocess(string):
# the order for some of these matter
# For example, you need to remove the commas in numbers before expanding them
string = remove_surrounded_chars(string)
string = string.replace('"', '')
string = string.replace('', '')
string = string.replace('\n', ' ')
string = convert_num_locale(string)
string = replace_negative(string)
string = replace_roman(string)
string = hyphen_range_to(string)
string = num_to_words(string)
# TODO Try to use a ML predictor to expand abbreviations. It's hard, dependent on context, and whether to actually
# try to say the abbreviation or spell it out as I've done below is not agreed upon
# For now, expand abbreviations to pronunciations
# replace_abbreviations adds a lot of unnecessary whitespace to ensure separation
string = replace_abbreviations(string)
# cleanup whitespaces
# remove whitespace before punctuation
string = re.sub(rf'\s+({punctuation})', r'\1', string)
string = string.strip()
# compact whitespace
string = ' '.join(string.split())
return string
def remove_surrounded_chars(string):
# this expression matches to 'as few symbols as possible (0 upwards) between any asterisks' OR
# 'as few symbols as possible (0 upwards) between an asterisk and the end of the string'
return re.sub(r'\*[^*]*?(\*|$)', '', string)
def replace_negative(string):
# handles situations like -5. -5 would become negative 5, which would then be expanded to negative five
return re.sub(rf'(\s)(-)(\d+)({punctuation})', r'\1negative \3\4', string)
def replace_roman(string):
# find a string of roman numerals.
# Only 2 or more, to avoid capturing I and single character abbreviations, like names
pattern = re.compile(rf'\s[IVXLCDM]{{2,}}{punctuation}')
result = string
while True:
match = pattern.search(result)
if match is None:
break
start = match.start()
end = match.end()
result = result[0:start + 1] + str(roman_to_int(result[start + 1:end - 1])) + result[end - 1:len(result)]
return result
def roman_to_int(s):
rom_val = {'I': 1, 'V': 5, 'X': 10, 'L': 50, 'C': 100, 'D': 500, 'M': 1000}
int_val = 0
for i in range(len(s)):
if i > 0 and rom_val[s[i]] > rom_val[s[i - 1]]:
int_val += rom_val[s[i]] - 2 * rom_val[s[i - 1]]
else:
int_val += rom_val[s[i]]
return int_val
def hyphen_range_to(text):
pattern = re.compile(r'(\d+)[-](\d+)')
result = pattern.sub(lambda x: x.group(1) + ' to ' + x.group(2), text)
return result
def num_to_words(text):
# 1000 or 10.23
pattern = re.compile(r'\d+\.\d+|\d+')
result = pattern.sub(lambda x: num2words(float(x.group())), text)
return result
def replace_abbreviations(string):
# abbreviations 1 to 4 characters long. It will get things like A and I, but those are pronounced with their letter
pattern = re.compile(rf'(^|[\s("\'\[<])([A-Z]{{1,4}})({punctuation}|$)')
result = string
while True:
match = pattern.search(result)
if match is None:
break
start = match.start()
end = match.end()
result = result[0:start] + replace_abbreviation(result[start:end]) + result[end:len(result)]
return result
def replace_abbreviation(string):
result = ""
for char in string:
result = match_mapping(char, result)
return result
def match_mapping(char, result):
for mapping in alphabet_map.keys():
if char == mapping:
return result + alphabet_map[char]
return result + char
def convert_num_locale(text):
# This detects locale and converts it to American without comma separators
pattern = re.compile(r'(?:\s|^)\d{1,3}(?:\.\d{3})*(?:,\d+)?(?:\s|$)')
result = text
while True:
match = pattern.search(result)
if match is None:
break
start = match.start()
end = match.end()
result = result[0:start] + result[start:end].replace('.', '').replace(',', '.') + result[end:len(result)]
# removes comma separators from existing American numbers
pattern = re.compile(r'(\d),(\d)')
result = pattern.sub(r'\1\2', result)
return result
def __main__(args):
print(preprocess(args[1]))
if __name__ == "__main__":
import sys
__main__(sys.argv)