import re
import tempfile
import gradio as gr
import numpy as np
import soundfile as sf
import torchaudio
from cached_path import cached_path
from transformers import AutoModelForCausalLM, AutoTokenizer
from num2words import num2words
try:
    import spaces

    USING_SPACES = True
except ImportError:
    USING_SPACES = False


def gpu_decorator(func):
    # On Hugging Face Spaces, request a GPU via the `spaces` decorator; otherwise return func unchanged.
    if USING_SPACES:
        return spaces.GPU(func)
    else:
        return func
from f5_tts.model import DiT, UNetT
from f5_tts.infer.utils_infer import (
    load_vocoder,
    load_model,
    preprocess_ref_audio_text,
    infer_process,
    remove_silence_for_generated_wav,
    save_spectrogram,
)
vocoder = load_vocoder()

# Load the Spanish fine-tuned F5-TTS model
F5TTS_model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
F5TTS_ema_model = load_model(
    DiT, F5TTS_model_cfg, str(cached_path("hf://jpgallegoar/F5-Spanish/model_1200000.safetensors"))
)
chat_model_state = None
chat_tokenizer_state = None
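# Globals reserved for an optional chat LLM (e.g. Qwen); they are never populated in this
# file, so generate_response() below is only usable if a model is loaded elsewhere.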
@gpu_decorator
def generate_response(messages, model, tokenizer):
    """Generate a response using Qwen."""
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
    generated_ids = model.generate(
        **model_inputs,
        max_new_tokens=512,
        do_sample=True,  # sampling must be enabled for temperature/top_p to take effect
        temperature=0.7,
        top_p=0.95,
    )
    # Strip the prompt tokens so only the newly generated text is decoded
    generated_ids = [
        output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
    ]
    return tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
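
# Hypothetical usage, assuming a Qwen chat model had been loaded into the globals above:
#   messages = [{"role": "user", "content": "Hola, ¿cómo estás?"}]
#   reply = generate_response(messages, chat_model_state, chat_tokenizer_state)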
def traducir_numero_a_texto(texto):
    # Separate digits glued to letters so each number becomes its own token
    texto_separado = re.sub(r'([A-Za-z])(\d)', r'\1 \2', texto)
    texto_separado = re.sub(r'(\d)([A-Za-z])', r'\1 \2', texto_separado)

    def reemplazar_numero(match):
        numero = match.group()
        return num2words(int(numero), lang='es')

    # Spell out every standalone integer in Spanish
    texto_traducido = re.sub(r'\b\d+\b', reemplazar_numero, texto_separado)
    return texto_traducido
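
# Example: traducir_numero_a_texto("tengo 3 gatos") -> "tengo tres gatos"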
@gpu_decorator
def infer(
    ref_audio_orig, ref_text, gen_text, model, remove_silence, cross_fade_duration=0.15, speed=1, show_info=gr.Info
):
    ref_audio, ref_text = preprocess_ref_audio_text(ref_audio_orig, ref_text, show_info=show_info)
    ema_model = F5TTS_ema_model

    # Normalize the text to generate: pad with spaces, lowercase, and spell out numbers
    if not gen_text.startswith(" "):
        gen_text = " " + gen_text
    if not gen_text.endswith(". "):
        gen_text += ". "
    gen_text = gen_text.lower()
    gen_text = traducir_numero_a_texto(gen_text)

    final_wave, final_sample_rate, combined_spectrogram = infer_process(
        ref_audio,
        ref_text,
        gen_text,
        ema_model,
        vocoder,
        cross_fade_duration=cross_fade_duration,
        speed=speed,
        show_info=show_info,
        progress=gr.Progress(),
    )

    # Optionally remove silences from the generated audio
    if remove_silence:
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
            sf.write(f.name, final_wave, final_sample_rate)
            remove_silence_for_generated_wav(f.name)
            final_wave, _ = torchaudio.load(f.name)
        final_wave = final_wave.squeeze().cpu().numpy()

    # Save the spectrogram to a temporary PNG
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_spectrogram:
        spectrogram_path = tmp_spectrogram.name
    save_spectrogram(combined_spectrogram, spectrogram_path)

    return (final_sample_rate, final_wave), spectrogram_path
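
# Hypothetical direct call outside the Gradio UI (the path and text are illustrative):
#   (sr, wave), spec_png = infer("referencia.wav", "", "Hola, son las 3", "F5-TTS", remove_silence=False)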
with gr.Blocks() as app_tts:
    gr.Markdown("# TTS por Lotes")
    ref_audio_input = gr.Audio(label="Audio de Referencia", type="filepath")
    gen_text_input = gr.Textbox(label="Texto para Generar", lines=10)
    model_choice = gr.Radio(choices=["F5-TTS"], label="Seleccionar Modelo TTS", value="F5-TTS")
    generate_btn = gr.Button("Sintetizar", variant="primary")
    with gr.Accordion("Configuraciones Avanzadas", open=False):
        ref_text_input = gr.Textbox(
            label="Texto de Referencia",
            info="Deja en blanco para transcribir automáticamente el audio de referencia. Si ingresas texto, sobrescribirá la transcripción automática.",
            lines=2,
        )
        remove_silence = gr.Checkbox(
            label="Eliminar Silencios",
            info="El modelo tiende a producir silencios, especialmente en audios más largos. Podemos eliminar manualmente los silencios si es necesario. Ten en cuenta que esta es una característica experimental y puede producir resultados extraños. Esto también aumentará el tiempo de generación.",
            value=False,
        )
        speed_slider = gr.Slider(
            label="Velocidad",
            minimum=0.3,
            maximum=2.0,
            value=1.0,
            step=0.1,
            info="Ajusta la velocidad del audio.",
        )
        cross_fade_duration_slider = gr.Slider(
            label="Duración del Cross-Fade (s)",
            minimum=0.0,
            maximum=1.0,
            value=0.15,
            step=0.01,
            info="Establece la duración del cross-fade entre clips de audio.",
        )
    audio_output = gr.Audio(label="Audio Sintetizado")
    spectrogram_output = gr.Image(label="Espectrograma")

    generate_btn.click(
        infer,
        inputs=[
            ref_audio_input,
            ref_text_input,
            gen_text_input,
            model_choice,
            remove_silence,
            cross_fade_duration_slider,
            speed_slider,
        ],
        outputs=[audio_output, spectrogram_output],
    )
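# Note: model_choice is forwarded as infer()'s `model` argument; since "F5-TTS" is the
# only choice, infer() always uses the Spanish F5-TTS checkpoint loaded above.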
with gr.Blocks() as app:
    gr.Markdown(
        """
# Spanish-F5

Esta es una interfaz web para F5 TTS, con un finetuning para poder hablar en castellano.
"""
    )
    gr.TabbedInterface(
        [app_tts],
        ["TTS"],
    )


if __name__ == "__main__":
    app.queue().launch()