import logging
import os

import gradio as gr
import numpy as np
import spaces
import torch
import torchaudio
import whisper
from huggingface_hub import hf_hub_download, login
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

from generator import Segment, load_csm_1b
from watermarking import watermark

logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')

try:
    api_key = os.getenv("HF_TOKEN")
    if not api_key:
        raise ValueError("HF_TOKEN not found in environment variables.")
    login(token=api_key)

    # WATERMARK_KEY is a space-separated list of integers. Check it exists
    # before splitting so a missing key raises a clear ValueError instead of
    # an AttributeError on None.
    watermark_key = os.getenv("WATERMARK_KEY")
    if not watermark_key:
        raise ValueError("WATERMARK_KEY not found or invalid in environment variables.")
    CSM_1B_HF_WATERMARK = list(map(int, watermark_key.split(" ")))

    gpu_timeout = int(os.getenv("GPU_TIMEOUT", 120))
except (ValueError, TypeError) as e:
    logging.error(f"Configuration error: {e}")
    raise

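# Example environment configuration (values are illustrative, not real secrets):
#   HF_TOKEN="hf_..."          # Hugging Face access token
#   WATERMARK_KEY="1 2 3 4 5"  # space-separated integers
#   GPU_TIMEOUT="120"          # seconds of GPU time requested per call
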
SPACE_INTRO_TEXT = """
# Sesame CSM 1B - Conversational Demo

This demo lets you hold a conversation with Sesame CSM 1B, using Whisper for speech-to-text and Gemma for generating responses. This is an experimental integration and may require significant resources.

*Disclaimer: This demo relies on several large models. Expect longer processing times and potential resource limitations.*
"""

SPEAKER_ID = 0
MAX_CONTEXT_SEGMENTS = 1
MAX_GEMMA_LENGTH = 150

# Rolling context of Segment objects fed back to the speech generator.
conversation_history = []

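# Pipeline per turn: user audio -> Whisper transcription -> Gemma reply text
# -> CSM speech synthesis -> watermark -> int16 PCM returned to Gradio.
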
def transcribe_audio(audio_path: str, whisper_model) -> str:
    try:
        audio = whisper.load_audio(audio_path)
        # pad_or_trim fits the audio to Whisper's fixed 30-second window.
        audio = whisper.pad_or_trim(audio)
        result = whisper_model.transcribe(audio)
        return result["text"]
    except Exception as e:
        logging.error(f"Whisper transcription error: {e}")
        return "Error: Could not transcribe audio."

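# Illustrative variation (not used here): whisper_model.transcribe() also
# accepts a file path directly and chunks recordings longer than 30 seconds,
# if full-length transcription were ever needed:
#
#     result = whisper_model.transcribe(audio_path)
#
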
def generate_response(text: str, model_gemma, tokenizer_gemma, device) -> str:
    try:
        messages = [{"role": "user", "content": text}]
        # add_generation_prompt=True appends the assistant-turn marker so the
        # model answers rather than continuing the user's message.
        input_ids = tokenizer_gemma.apply_chat_template(
            messages, add_generation_prompt=True, return_tensors="pt"
        ).to(device)
        # early_stopping is omitted: it only applies to beam search.
        generation_config = GenerationConfig(
            max_new_tokens=MAX_GEMMA_LENGTH,
        )

        generated_output = model_gemma.generate(input_ids, generation_config=generation_config)
        # Decode only the newly generated tokens; otherwise the prompt is
        # echoed back at the start of the reply.
        return tokenizer_gemma.decode(
            generated_output[0][input_ids.shape[-1]:], skip_special_tokens=True
        )
    except Exception as e:
        logging.error(f"Gemma response generation error: {e}")
        return "I'm sorry, I encountered an error generating a response."

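# Illustrative alternative: sampling instead of greedy decoding, for more
# varied replies (parameter values are examples, not tuned):
#
#     generation_config = GenerationConfig(
#         max_new_tokens=MAX_GEMMA_LENGTH,
#         do_sample=True,
#         temperature=0.7,
#         top_p=0.9,
#     )
#
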
def load_audio(audio_path: str, generator) -> torch.Tensor:
    try:
        audio_tensor, sample_rate = torchaudio.load(audio_path)
        # torchaudio returns (channels, frames); downmix to mono, then match
        # the generator's sample rate.
        audio_tensor = audio_tensor.mean(dim=0)
        if sample_rate != generator.sample_rate:
            audio_tensor = torchaudio.functional.resample(
                audio_tensor, orig_freq=sample_rate, new_freq=generator.sample_rate
            )
        return audio_tensor
    except Exception as e:
        logging.error(f"Audio loading error: {e}")
        raise gr.Error("Could not load or process the audio file.") from e

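# The returned mono tensor is wrapped in a Segment below; CSM's context
# segments appear to expect 1-D audio at generator.sample_rate (inferred from
# this script's usage of load_csm_1b, not from a documented API).
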
def clear_history():
    global conversation_history
    conversation_history = []
    logging.info("Conversation history cleared.")
    return "Conversation history cleared."

@spaces.GPU(duration=gpu_timeout)
def infer(user_audio) -> tuple[int, np.ndarray]:
    if torch.cuda.is_available():
        logging.info(f"CUDA is available! Device count: {torch.cuda.device_count()}")
        logging.info(f"CUDA device name: {torch.cuda.get_device_name(0)}")
        logging.info(f"CUDA version: {torch.version.cuda}")
        device = "cuda"
    else:
        logging.info("CUDA is NOT available. Using CPU.")
        device = "cpu"

    try:
        # Validate the input before spending time loading models.
        if not user_audio:
            raise ValueError("No audio input received.")

        model_path = hf_hub_download(repo_id="sesame/csm-1b", filename="ckpt.pt")
        generator = load_csm_1b(model_path, device)
        logging.info("Sesame CSM 1B loaded successfully.")

        whisper_model = whisper.load_model("small.en", device=device)
        logging.info("Whisper model loaded successfully.")

        tokenizer_gemma = AutoTokenizer.from_pretrained("google/gemma-3-1b-it")
        model_gemma = AutoModelForCausalLM.from_pretrained("google/gemma-3-1b-it").to(device)
        logging.info("Gemma 3 1B it model loaded successfully.")

        return _infer(user_audio, generator, whisper_model, tokenizer_gemma, model_gemma, device)
    except Exception as e:
        logging.exception(f"Inference error: {e}")
        raise gr.Error(f"An error occurred during processing: {e}")

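# Note: all three models are reloaded on every call. Inside the @spaces.GPU
# function this keeps the weights on the device ZeroGPU allocates for the
# current call; on a persistent-GPU Space they could instead be loaded once at
# module scope and reused, e.g. (illustrative, not used here):
#
#     WHISPER_MODEL = whisper.load_model("small.en")
#
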
def _infer(user_audio, generator, whisper_model, tokenizer_gemma, model_gemma, device) -> tuple[int, np.ndarray]:
    global conversation_history

    try:
        user_text = transcribe_audio(user_audio, whisper_model)
        logging.info(f"User: {user_text}")

        ai_text = generate_response(user_text, model_gemma, tokenizer_gemma, device)
        logging.info(f"AI: {ai_text}")

        try:
            ai_audio = generator.generate(
                text=ai_text,
                speaker=SPEAKER_ID,
                context=conversation_history,
                max_audio_length_ms=10_000,
            )
            logging.info("Audio generated successfully.")
        except Exception as e:
            logging.error(f"Sesame response generation error: {e}")
            raise gr.Error(f"Sesame response generation error: {e}")

        user_segment = Segment(speaker=1, text=user_text, audio=load_audio(user_audio, generator))
        ai_segment = Segment(speaker=SPEAKER_ID, text=ai_text, audio=ai_audio)
        conversation_history.append(user_segment)
        conversation_history.append(ai_segment)

        # Two segments are added per turn, so trim in a loop; a single pop(0)
        # would let the history grow by one segment every call.
        while len(conversation_history) > MAX_CONTEXT_SEGMENTS:
            conversation_history.pop(0)

        audio_tensor, wm_sample_rate = watermark(
            generator._watermarker, ai_audio, generator.sample_rate, CSM_1B_HF_WATERMARK
        )
        # The watermarker may return audio at its own rate; resample back to
        # the generator's rate before playback.
        audio_tensor = torchaudio.functional.resample(
            audio_tensor, orig_freq=wm_sample_rate, new_freq=generator.sample_rate
        )

        # Clamp to [-1, 1] before scaling so the int16 conversion cannot overflow.
        ai_audio_array = (torch.clamp(audio_tensor, -1.0, 1.0) * 32767).to(torch.int16).cpu().numpy()
        return generator.sample_rate, ai_audio_array

    except Exception as e:
        logging.exception(f"Error in _infer: {e}")
        raise gr.Error(f"An error occurred during processing: {e}")

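# gr.Audio accepts a (sample_rate, numpy_array) tuple as output; the int16
# conversion above matches the integer PCM format Gradio plays back directly.
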
with gr.Blocks() as app:
    gr.Markdown(SPACE_INTRO_TEXT)
    audio_input = gr.Audio(label="Your Input", type="filepath")
    audio_output = gr.Audio(label="AI Response")
    clear_button = gr.Button("Clear Conversation History")
    status_display = gr.Textbox(label="Status", visible=False)

    btn = gr.Button("Generate Response")
    btn.click(infer, inputs=[audio_input], outputs=[audio_output])
    clear_button.click(clear_history, outputs=[status_display])

app.launch(ssr_mode=False, share=True)