import argparse
import functools
import importlib.util
import json
import os
from pathlib import Path
import re
import time

try:
    import spaces
except ImportError:
    class _SpacesFallback:
        @staticmethod
        def GPU(*_args, **_kwargs):
            def _decorator(func):
                return func

            return _decorator

    spaces = _SpacesFallback()

import gradio as gr
import numpy as np
import torch
from transformers import AutoModel, AutoProcessor

# Disable the broken cuDNN SDPA backend
torch.backends.cuda.enable_cudnn_sdp(False)
# Keep these enabled as fallbacks
torch.backends.cuda.enable_flash_sdp(True)
torch.backends.cuda.enable_mem_efficient_sdp(True)
torch.backends.cuda.enable_math_sdp(True)

MODEL_PATH = "OpenMOSS-Team/MOSS-VoiceGenerator"
DEFAULT_ATTN_IMPLEMENTATION = "auto"
DEFAULT_MAX_NEW_TOKENS = 4096
PRELOAD_ENV_VAR = "MOSS_VOICE_GENERATOR_PRELOAD_AT_STARTUP"
EXAMPLE_TEXTS_JSONL_PATH = Path(__file__).resolve().parent / "text" / "moss_voice_generator_example_texts.jsonl"


def _parse_example_id(example_id: str) -> tuple[str, int] | None:
    matched = re.fullmatch(r"(zh|en)/(\d+)", (example_id or "").strip())
    if matched is None:
        return None
    return matched.group(1), int(matched.group(2))


def build_example_rows() -> list[tuple[str, str, str]]:
    rows: list[tuple[str, int, str, str]] = []
    with open(EXAMPLE_TEXTS_JSONL_PATH, "r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            sample = json.loads(line)
            parsed = _parse_example_id(sample.get("id", ""))
            if parsed is None:
                continue
            language, index = parsed
            instruction = str(sample.get("instruction", "")).strip()
            text = str(sample.get("text", "")).strip()
            rows.append((language, index, instruction, text))
    language_order = {"zh": 0, "en": 1}
    rows.sort(key=lambda item: (language_order.get(item[0], 99), item[1]))
    return [(f"{language}/{index}", instruction, text) for language, index, instruction, text in rows]


EXAMPLE_ROWS = build_example_rows()


def apply_example_selection(evt: gr.SelectData):
    if evt is None or evt.index is None:
        return gr.update(), gr.update()
    if isinstance(evt.index, (tuple, list)):
        row_idx = int(evt.index[0])
    else:
        row_idx = int(evt.index)
    if row_idx < 0 or row_idx >= len(EXAMPLE_ROWS):
        return gr.update(), gr.update()
    _, instruction_value, text_value = EXAMPLE_ROWS[row_idx]
    return instruction_value, text_value


def resolve_attn_implementation(requested: str, device: torch.device, dtype: torch.dtype) -> str | None:
    requested_norm = (requested or "").strip().lower()
    if requested_norm == "none":
        return None
    if requested_norm not in {"", "auto"}:
        return requested
    # Prefer FlashAttention 2 when package + device conditions are met.
    if (
        device.type == "cuda"
        and importlib.util.find_spec("flash_attn") is not None
        and dtype in {torch.float16, torch.bfloat16}
    ):
        major, _ = torch.cuda.get_device_capability(device)
        if major >= 8:
            return "flash_attention_2"
    # CUDA fallback: use PyTorch SDPA kernels.
    if device.type == "cuda":
        return "sdpa"
    # CPU fallback.
    return "eager"
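

# Illustrative outcomes of the "auto" policy above (derived from the logic of
# resolve_attn_implementation, not an exhaustive matrix):
#   CUDA + flash-attn installed + fp16/bf16 + compute capability >= 8 -> "flash_attention_2"
#   any other CUDA configuration                                      -> "sdpa"
#   CPU                                                               -> "eager"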
return "eager" @functools.lru_cache(maxsize=1) def load_backend(model_path: str, device_str: str, attn_implementation: str): device = torch.device(device_str if torch.cuda.is_available() else "cpu") dtype = torch.bfloat16 if device.type == "cuda" else torch.float32 resolved_attn_implementation = resolve_attn_implementation( requested=attn_implementation, device=device, dtype=dtype, ) processor = AutoProcessor.from_pretrained( model_path, trust_remote_code=True, normalize_inputs=True, ) if hasattr(processor, "audio_tokenizer"): processor.audio_tokenizer = processor.audio_tokenizer.to(device) processor.audio_tokenizer.eval() model_kwargs = { "trust_remote_code": True, "torch_dtype": dtype, } if resolved_attn_implementation: model_kwargs["attn_implementation"] = resolved_attn_implementation model = AutoModel.from_pretrained(model_path, **model_kwargs).to(device) model.eval() sample_rate = int(getattr(processor.model_config, "sampling_rate", 24000)) return model, processor, device, sample_rate def build_conversation(text: str, instruction: str, processor): text = (text or "").strip() instruction = (instruction or "").strip() if not text: raise ValueError("Please enter text to synthesize.") if not instruction: raise ValueError("Please enter a voice instruction.") return [[processor.build_user_message(text=text, instruction=instruction)]] @spaces.GPU(duration=180) def run_inference( text: str, instruction: str, temperature: float, top_p: float, top_k: int, repetition_penalty: float, max_new_tokens: int, model_path: str, device: str, attn_implementation: str, ): started_at = time.monotonic() model, processor, torch_device, sample_rate = load_backend( model_path=model_path, device_str=device, attn_implementation=attn_implementation, ) conversations = build_conversation( text=text, instruction=instruction, processor=processor, ) batch = processor(conversations, mode="generation") input_ids = batch["input_ids"].to(torch_device) attention_mask = batch["attention_mask"].to(torch_device) with torch.no_grad(): outputs = model.generate( input_ids=input_ids, attention_mask=attention_mask, max_new_tokens=int(max_new_tokens), audio_temperature=float(temperature), audio_top_p=float(top_p), audio_top_k=int(top_k), audio_repetition_penalty=float(repetition_penalty), ) messages = processor.decode(outputs) if not messages or messages[0] is None: raise RuntimeError("The model did not return a decodable audio result.") audio = messages[0].audio_codes_list[0] if isinstance(audio, torch.Tensor): audio_np = audio.detach().float().cpu().numpy() else: audio_np = np.asarray(audio, dtype=np.float32) if audio_np.ndim > 1: audio_np = audio_np.reshape(-1) audio_np = audio_np.astype(np.float32, copy=False) elapsed = time.monotonic() - started_at status = ( f"Done | elapsed: {elapsed:.2f}s | " f"max_new_tokens={int(max_new_tokens)}, " f"audio_temperature={float(temperature):.2f}, audio_top_p={float(top_p):.2f}, " f"audio_top_k={int(top_k)}, audio_repetition_penalty={float(repetition_penalty):.2f}" ) return (sample_rate, audio_np), status def build_demo(args: argparse.Namespace): custom_css = """ :root { --bg: #f6f7f8; --panel: #ffffff; --ink: #111418; --muted: #4d5562; --line: #e5e7eb; --accent: #0f766e; } .gradio-container { background: linear-gradient(180deg, #f7f8fa 0%, #f3f5f7 100%); color: var(--ink); } .app-card { border: 1px solid var(--line); border-radius: 16px; background: var(--panel); padding: 14px; } .app-title { font-size: 22px; font-weight: 700; margin-bottom: 6px; letter-spacing: 0.2px; } .app-subtitle { 


def build_demo(args: argparse.Namespace):
    custom_css = """
    :root {
        --bg: #f6f7f8;
        --panel: #ffffff;
        --ink: #111418;
        --muted: #4d5562;
        --line: #e5e7eb;
        --accent: #0f766e;
    }
    .gradio-container {
        background: linear-gradient(180deg, #f7f8fa 0%, #f3f5f7 100%);
        color: var(--ink);
    }
    .app-card {
        border: 1px solid var(--line);
        border-radius: 16px;
        background: var(--panel);
        padding: 14px;
    }
    .app-title {
        font-size: 22px;
        font-weight: 700;
        margin-bottom: 6px;
        letter-spacing: 0.2px;
    }
    .app-subtitle {
        color: var(--muted);
        font-size: 14px;
        margin-bottom: 8px;
    }
    #output_audio {
        padding-bottom: 12px;
        margin-bottom: 8px;
        overflow: hidden !important;
    }
    #output_audio > .wrap {
        overflow: hidden !important;
    }
    #output_audio audio {
        margin-bottom: 6px;
    }
    #run-btn {
        background: var(--accent);
        border: none;
    }
    """

    with gr.Blocks(title="MOSS-VoiceGenerator Demo", css=custom_css) as demo:
        gr.Markdown(
            """
<div class="app-card">
  <div class="app-title">MOSS-VoiceGenerator</div>
  <div class="app-subtitle">Design expressive voices from instruction + text without reference audio.</div>
</div>
""" ) with gr.Row(equal_height=False): with gr.Column(scale=3): instruction = gr.Textbox( label="Voice Instruction", lines=5, placeholder="Example: Warm, gentle female narrator voice with calm pacing and clear articulation.", ) text = gr.Textbox( label="Text", lines=8, placeholder="Enter the text content to synthesize with the instruction-defined voice.", ) with gr.Accordion("Sampling Parameters (Audio)", open=True): temperature = gr.Slider( minimum=0.1, maximum=3.0, step=0.05, value=1.5, label="temperature", ) top_p = gr.Slider( minimum=0.1, maximum=1.0, step=0.01, value=0.6, label="top_p", ) top_k = gr.Slider( minimum=1, maximum=200, step=1, value=50, label="top_k", ) repetition_penalty = gr.Slider( minimum=0.8, maximum=2.0, step=0.05, value=1.1, label="repetition_penalty", ) max_new_tokens = gr.Slider( minimum=256, maximum=8192, step=128, value=DEFAULT_MAX_NEW_TOKENS, label="max_new_tokens", ) run_btn = gr.Button("Generate Voice", variant="primary", elem_id="run-btn") with gr.Column(scale=2): output_audio = gr.Audio(label="Output Audio", type="numpy", elem_id="output_audio") status = gr.Textbox(label="Status", lines=4, interactive=False) examples_table = gr.Dataframe( headers=["Voice Instruction", "Example Text"], value=[[example_instruction, example_text] for _, example_instruction, example_text in EXAMPLE_ROWS], datatype=["str", "str"], row_count=(len(EXAMPLE_ROWS), "fixed"), col_count=(2, "fixed"), interactive=False, wrap=True, label="Examples (click a row to fill inputs)", ) examples_table.select( fn=apply_example_selection, inputs=[], outputs=[instruction, text], ) run_btn.click( fn=run_inference, inputs=[ text, instruction, temperature, top_p, top_k, repetition_penalty, max_new_tokens, gr.State(args.model_path), gr.State(args.device), gr.State(args.attn_implementation), ], outputs=[output_audio, status], ) return demo def resolve_runtime_attn(args: argparse.Namespace) -> argparse.Namespace: runtime_device = torch.device(args.device if torch.cuda.is_available() else "cpu") runtime_dtype = torch.bfloat16 if runtime_device.type == "cuda" else torch.float32 args.attn_implementation = resolve_attn_implementation( requested=args.attn_implementation, device=runtime_device, dtype=runtime_dtype, ) or "none" return args def parse_bool_env(name: str, default: bool) -> bool: value = os.getenv(name) if value is None: return default return value.strip().lower() in {"1", "true", "yes", "y", "on"} def parse_port(value: str | None, default: int) -> int: if not value: return default try: return int(value) except ValueError: return default def main(): parser = argparse.ArgumentParser(description="MOSS-VoiceGenerator Gradio Demo") parser.add_argument("--model_path", type=str, default=MODEL_PATH) parser.add_argument("--device", type=str, default="cuda:0") parser.add_argument("--attn_implementation", type=str, default=DEFAULT_ATTN_IMPLEMENTATION) parser.add_argument("--host", type=str, default="0.0.0.0") parser.add_argument( "--port", type=int, default=int(os.getenv("GRADIO_SERVER_PORT", os.getenv("PORT", "7860"))), ) parser.add_argument("--share", action="store_true") args = parser.parse_args() args.host = os.getenv("GRADIO_SERVER_NAME", args.host) args.port = parse_port(os.getenv("GRADIO_SERVER_PORT", os.getenv("PORT")), args.port) args = resolve_runtime_attn(args) print(f"[INFO] Using attn_implementation={args.attn_implementation}", flush=True) preload_enabled = parse_bool_env(PRELOAD_ENV_VAR, default=not bool(os.getenv("SPACE_ID"))) if preload_enabled: preload_started_at = time.monotonic() print( 
f"[Startup] Preloading backend: model={args.model_path}, device={args.device}, attn={args.attn_implementation}", flush=True, ) load_backend( model_path=args.model_path, device_str=args.device, attn_implementation=args.attn_implementation, ) print( f"[Startup] Backend preload finished in {time.monotonic() - preload_started_at:.2f}s", flush=True, ) else: print( f"[Startup] Skipping preload (set {PRELOAD_ENV_VAR}=1 to enable).", flush=True, ) demo = build_demo(args) demo.queue(max_size=16, default_concurrency_limit=1).launch( server_name=args.host, server_port=args.port, share=args.share, ssr_mode=False, ) if __name__ == "__main__": main()