import argparse
import functools
import importlib.util
import os
from pathlib import Path
import re
import time

import orjson

try:
    import spaces
except ImportError:
    class _SpacesFallback:
        @staticmethod
        def GPU(*_args, **_kwargs):
            def _decorator(func):
                return func

            return _decorator

    spaces = _SpacesFallback()

import gradio as gr
import numpy as np
import torch
from transformers import AutoModel, AutoProcessor

# Disable the broken cuDNN SDPA backend
torch.backends.cuda.enable_cudnn_sdp(False)
# Keep these enabled as fallbacks
torch.backends.cuda.enable_flash_sdp(True)
torch.backends.cuda.enable_mem_efficient_sdp(True)
torch.backends.cuda.enable_math_sdp(True)

MODEL_PATH = "OpenMOSS-Team/MOSS-TTS"
DEFAULT_ATTN_IMPLEMENTATION = "auto"
DEFAULT_MAX_NEW_TOKENS = 4096
PRELOAD_ENV_VAR = "MOSS_TTS_PRELOAD_AT_STARTUP"
CONTINUATION_NOTICE = (
    "Continuation mode is active. Make sure the reference audio transcript is prepended to the input text."
)
MODE_CLONE = "Clone"
MODE_CONTINUE = "Continuation"
MODE_CONTINUE_CLONE = "Continuation + Clone"
ZH_TOKENS_PER_CHAR = 3.098411951313033
EN_TOKENS_PER_CHAR = 0.8673376262755219
REFERENCE_AUDIO_DIR = Path(__file__).resolve().parent / "assets" / "audio"
EXAMPLE_TEXTS_JSONL_PATH = Path(__file__).resolve().parent / "assets" / "text" / "moss_tts_example_texts.jsonl"


def _parse_example_id(example_id: str) -> tuple[str, int] | None:
    matched = re.fullmatch(r"(zh|en)/(\d+)", (example_id or "").strip())
    if matched is None:
        return None
    return matched.group(1), int(matched.group(2))


def _resolve_reference_audio_path(language: str, index: int) -> Path | None:
    stem_candidates = [f"reference_{language}_{index}"]
    for stem in stem_candidates:
        for ext in (".wav", ".mp3"):
            audio_path = REFERENCE_AUDIO_DIR / f"{stem}{ext}"
            if audio_path.exists():
                return audio_path
    return None


def build_example_rows() -> list[tuple[str, str, str]]:
    rows: list[tuple[str, str, str]] = []
    with open(EXAMPLE_TEXTS_JSONL_PATH, "rb") as f:
        for line in f:
            if not line.strip():
                continue
            sample = orjson.loads(line)
            parsed = _parse_example_id(sample.get("id", ""))
            if parsed is None:
                continue
            language, index = parsed
            text = str(sample.get("text", "")).strip()
            audio_path = _resolve_reference_audio_path(language, index)
            if audio_path is None:
                continue
            rows.append((sample["role"], str(audio_path), text))
    return rows


EXAMPLE_ROWS = build_example_rows()


@functools.lru_cache(maxsize=1)
def load_backend(model_path: str, device_str: str, attn_implementation: str):
    device = torch.device(device_str if torch.cuda.is_available() else "cpu")
    dtype = torch.bfloat16 if device.type == "cuda" else torch.float32
    resolved_attn_implementation = resolve_attn_implementation(
        requested=attn_implementation,
        device=device,
        dtype=dtype,
    )
    processor = AutoProcessor.from_pretrained(
        model_path,
        trust_remote_code=True,
    )
    if hasattr(processor, "audio_tokenizer"):
        processor.audio_tokenizer = processor.audio_tokenizer.to(device)
        processor.audio_tokenizer.eval()
    model_kwargs = {
        "trust_remote_code": True,
        "torch_dtype": dtype,
    }
    if resolved_attn_implementation:
        model_kwargs["attn_implementation"] = resolved_attn_implementation
    model = AutoModel.from_pretrained(model_path, **model_kwargs).to(device)
    model.eval()
    sample_rate = int(getattr(processor.model_config, "sampling_rate", 24000))
    return model, processor, device, sample_rate
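

# Resolve the requested attention setting ("auto", "none", or an explicit value)
# to a concrete `attn_implementation` string for transformers, based on device
# type, dtype, GPU compute capability, and whether flash_attn is installed.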
def resolve_attn_implementation(requested: str, device: torch.device, dtype: torch.dtype) -> str | None:
    requested_norm = (requested or "").strip().lower()
    if requested_norm in {"none"}:
        return None
    if requested_norm not in {"", "auto"}:
        return requested
    # Prefer FlashAttention 2 when package + device conditions are met.
    if (
        device.type == "cuda"
        and importlib.util.find_spec("flash_attn") is not None
        and dtype in {torch.float16, torch.bfloat16}
    ):
        major, _ = torch.cuda.get_device_capability(device)
        if major >= 8:
            return "flash_attention_2"
    # CUDA fallback: use PyTorch SDPA kernels.
    if device.type == "cuda":
        return "sdpa"
    # CPU fallback.
    return "eager"


def detect_text_language(text: str) -> str:
    zh_chars = len(re.findall(r"[\u4e00-\u9fff]", text))
    en_chars = len(re.findall(r"[A-Za-z]", text))
    if zh_chars == 0 and en_chars == 0:
        return "en"
    return "zh" if zh_chars >= en_chars else "en"


def supports_duration_control(mode_with_reference: str) -> bool:
    return mode_with_reference not in {MODE_CONTINUE, MODE_CONTINUE_CLONE}


def estimate_duration_tokens(text: str) -> tuple[str, int, int, int]:
    normalized = text or ""
    effective_len = max(len(normalized), 1)
    language = detect_text_language(normalized)
    factor = ZH_TOKENS_PER_CHAR if language == "zh" else EN_TOKENS_PER_CHAR
    default_tokens = max(1, int(effective_len * factor))
    min_tokens = max(1, int(default_tokens * 0.5))
    max_tokens = max(min_tokens, int(default_tokens * 1.5))
    return language, default_tokens, min_tokens, max_tokens
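

# Build the Gradio component updates for the duration controls: slider
# visibility/range/value, a status hint string, and the checkbox state,
# given the current text and synthesis mode.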
def update_duration_controls(
    enabled: bool,
    text: str,
    current_tokens: float | int | None,
    mode_with_reference: str,
):
    if not supports_duration_control(mode_with_reference):
        return (
            gr.update(visible=False),
            "Duration control is disabled for Continuation modes.",
            gr.update(value=False, interactive=False),
        )
    checkbox_update = gr.update(interactive=True)
    if not enabled:
        return gr.update(visible=False), "Duration control is disabled.", checkbox_update
    language, default_tokens, min_tokens, max_tokens = estimate_duration_tokens(text)
    # Slider is initialized with value=1 as a placeholder; treat it as "unset"
    # so first-time estimation uses the computed default instead of clamping to min.
    if current_tokens is None or int(current_tokens) == 1:
        slider_value = default_tokens
    else:
        slider_value = int(current_tokens)
    slider_value = max(min_tokens, min(max_tokens, slider_value))
    language_label = "Chinese" if language == "zh" else "English"
    hint = (
        f"Duration control enabled | detected language: {language_label} | "
        f"default={default_tokens}, range=[{min_tokens}, {max_tokens}]"
    )
    return (
        gr.update(
            visible=True,
            minimum=min_tokens,
            maximum=max_tokens,
            value=slider_value,
            step=1,
        ),
        hint,
        checkbox_update,
    )


def build_conversation(
    text: str,
    reference_audio: str | None,
    mode_with_reference: str,
    expected_tokens: int | None,
    processor,
):
    text = (text or "").strip()
    if not text:
        raise ValueError("Please enter text to synthesize.")
    user_kwargs = {"text": text}
    if expected_tokens is not None:
        user_kwargs["tokens"] = int(expected_tokens)
    if not reference_audio:
        conversations = [[processor.build_user_message(**user_kwargs)]]
        return conversations, "generation", "Direct Generation"
    if mode_with_reference == MODE_CLONE:
        clone_kwargs = dict(user_kwargs)
        clone_kwargs["reference"] = [reference_audio]
        conversations = [[processor.build_user_message(**clone_kwargs)]]
        return conversations, "generation", MODE_CLONE
    if mode_with_reference == MODE_CONTINUE:
        conversations = [
            [
                processor.build_user_message(**user_kwargs),
                processor.build_assistant_message(audio_codes_list=[reference_audio]),
            ]
        ]
        return conversations, "continuation", MODE_CONTINUE
    continue_clone_kwargs = dict(user_kwargs)
    continue_clone_kwargs["reference"] = [reference_audio]
    conversations = [
        [
            processor.build_user_message(**continue_clone_kwargs),
            processor.build_assistant_message(audio_codes_list=[reference_audio]),
        ]
    ]
    return conversations, "continuation", MODE_CONTINUE_CLONE


def render_mode_hint(reference_audio: str | None, mode_with_reference: str):
    if not reference_audio:
        return "Current mode: **Direct Generation** (no reference audio uploaded)"
    if mode_with_reference == MODE_CLONE:
        return "Current mode: **Clone** (speaker timbre will be cloned from the reference audio)"
    return f"Current mode: **{mode_with_reference}** \n> {CONTINUATION_NOTICE}"


def apply_example_selection(
    mode_with_reference: str,
    duration_control_enabled: bool,
    duration_tokens: int,
    evt: gr.SelectData,
):
    if evt is None or evt.index is None:
        return gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update()
    if isinstance(evt.index, (tuple, list)):
        row_idx = int(evt.index[0])
    else:
        row_idx = int(evt.index)
    if row_idx < 0 or row_idx >= len(EXAMPLE_ROWS):
        return gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update()
    _, audio_path, example_text = EXAMPLE_ROWS[row_idx]
    duration_slider_update, duration_hint, duration_checkbox_update = update_duration_controls(
        duration_control_enabled,
        example_text,
        duration_tokens,
        mode_with_reference,
    )
    return (
        audio_path,
        example_text,
        render_mode_hint(audio_path, mode_with_reference),
        duration_slider_update,
        duration_hint,
        duration_checkbox_update,
    )
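

# On Hugging Face ZeroGPU Spaces, `spaces.GPU(duration=180)` requests a GPU for
# the duration of the call; the fallback shim defined at import time makes the
# decorator a no-op when the `spaces` package is unavailable.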
@spaces.GPU(duration=180)
def run_inference(
    text: str,
    reference_audio: str | None,
    mode_with_reference: str,
    duration_control_enabled: bool,
    duration_tokens: int,
    temperature: float,
    top_p: float,
    top_k: int,
    repetition_penalty: float,
    model_path: str,
    device: str,
    attn_implementation: str,
    max_new_tokens: int,
):
    started_at = time.monotonic()
    model, processor, torch_device, sample_rate = load_backend(
        model_path=model_path,
        device_str=device,
        attn_implementation=attn_implementation,
    )
    duration_enabled = bool(duration_control_enabled and supports_duration_control(mode_with_reference))
    expected_tokens = int(duration_tokens) if duration_enabled else None
    conversations, mode, mode_name = build_conversation(
        text=text,
        reference_audio=reference_audio,
        mode_with_reference=mode_with_reference,
        expected_tokens=expected_tokens,
        processor=processor,
    )
    batch = processor(conversations, mode=mode)
    input_ids = batch["input_ids"].to(torch_device)
    attention_mask = batch["attention_mask"].to(torch_device)
    with torch.no_grad():
        outputs = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=int(max_new_tokens),
            audio_temperature=float(temperature),
            audio_top_p=float(top_p),
            audio_top_k=int(top_k),
            audio_repetition_penalty=float(repetition_penalty),
        )
    messages = processor.decode(outputs)
    if not messages or messages[0] is None:
        raise RuntimeError("The model did not return a decodable audio result.")
    audio = messages[0].audio_codes_list[0]
    if isinstance(audio, torch.Tensor):
        audio_np = audio.detach().float().cpu().numpy()
    else:
        audio_np = np.asarray(audio, dtype=np.float32)
    if audio_np.ndim > 1:
        audio_np = audio_np.reshape(-1)
    audio_np = audio_np.astype(np.float32, copy=False)
    elapsed = time.monotonic() - started_at
    status = (
        f"Done | mode: {mode_name} | elapsed: {elapsed:.2f}s | "
        f"max_new_tokens={int(max_new_tokens)}, "
        f"expected_tokens={expected_tokens if expected_tokens is not None else 'off'}, "
        f"audio_temperature={float(temperature):.2f}, audio_top_p={float(top_p):.2f}, "
        f"audio_top_k={int(top_k)}, audio_repetition_penalty={float(repetition_penalty):.2f}"
    )
    return (sample_rate, audio_np), status


def build_demo(args: argparse.Namespace):
    custom_css = """
    :root {
        --bg: #f6f7f8;
        --panel: #ffffff;
        --ink: #111418;
        --muted: #4d5562;
        --line: #e5e7eb;
        --accent: #0f766e;
    }
    .gradio-container {
        background: linear-gradient(180deg, #f7f8fa 0%, #f3f5f7 100%);
        color: var(--ink);
    }
    .app-card {
        border: 1px solid var(--line);
        border-radius: 16px;
        background: var(--panel);
        padding: 14px;
    }
    .app-title {
        font-size: 22px;
        font-weight: 700;
        margin-bottom: 6px;
        letter-spacing: 0.2px;
    }
    .app-subtitle {
        color: var(--muted);
        font-size: 14px;
        margin-bottom: 8px;
    }
    #output_audio {
        padding-bottom: 12px;
        margin-bottom: 8px;
        overflow: hidden !important;
    }
    #output_audio > .wrap {
        overflow: hidden !important;
    }
    #output_audio audio {
        margin-bottom: 6px;
    }
    #run-btn {
        background: var(--accent);
        border: none;
    }
    """
    with gr.Blocks(title="MOSS-TTS Demo", css=custom_css) as demo:
        gr.Markdown(
            """