rulerman committed on
Commit 15d2e10 · verified · 1 Parent(s): c76414e
Files changed (3)
  1. README.md +258 -3
  2. modeling_moss_tts.py +35 -20
  3. processing_moss_tts.py +32 -56
README.md CHANGED
@@ -1,3 +1,258 @@
1
- ---
2
- license: apache-2.0
3
- ---
1
+ ---
2
+ license: apache-2.0
3
+ ---
4
+ # MOSS-TTS Family
5
+
6
+ ## Overview
7
+ MOSS‑TTS Family is an open‑source **speech and sound generation model family** from [MOSI.AI](https://mosi.cn/#hero) and the [OpenMOSS team](https://www.open-moss.com/). It is designed for **high‑fidelity**, **high‑expressiveness**, and **complex real‑world scenarios**, covering stable long‑form speech, multi‑speaker dialogue, voice/character design, environmental sound effects, and real‑time streaming TTS.
8
+
9
+
10
+ ## Introduction
11
+
12
+ <p align="center">
13
+ <img src="./assets/moss_tts_family.jpeg" width="85%" />
14
+ </p>
15
+
16
+ When a single piece of audio needs to **sound like a real person**, **pronounce every word accurately**, **switch speaking styles across content**, **remain stable over tens of minutes**, and **support dialogue, role‑play, and real‑time interaction**, a single TTS model is often not enough. The **MOSS‑TTS Family** breaks the workflow into five production‑ready models that can be used independently or composed into a complete pipeline.
17
+
18
+ - **MOSS‑TTS**: MOSS-TTS is the flagship, production-ready Text-to-Speech foundation model in the MOSS-TTS Family, built to ship, scale, and deliver real-world voice applications beyond demos. It provides high-fidelity zero-shot voice cloning as the core capability, along with ultra-long speech generation, token-level duration control, multilingual and code-switched synthesis, and fine-grained Pinyin/phoneme pronunciation control. Together, these features make it a robust base model for scalable narration, dubbing, and voice-driven products.
19
+ - **MOSS‑TTSD**: MOSS-TTSD is a production-oriented long-form spoken dialogue generation model for creating highly expressive, multi-party conversational audio at scale. It supports continuous long-duration generation, flexible multi-speaker turn-taking control, and zero-shot voice cloning from short reference audio, enabling natural conversations with rich interaction dynamics. It is designed for real-world long-form content such as podcasts, audiobooks, commentary, dubbing, and entertainment dialogue.
20
+ - **MOSS‑VoiceGenerator**: MOSS-VoiceGenerator is an open-source voice design system that generates speaker timbres directly from free-form text descriptions, enabling fast creation of voices for characters, personalities, and emotions—without requiring reference audio. It unifies timbre design, style control, and content synthesis in a single instruction-driven model, producing high-fidelity, emotionally expressive speech that feels naturally human. It can be used standalone for creative production, or as a voice design layer that improves integration and usability for downstream TTS systems.
21
+ - **MOSS‑SoundEffect**: MOSS-SoundEffect is a high-fidelity sound effect generation model built for real-world content creation, offering strong environmental richness, broad category coverage, and reliable duration controllability. Trained on large-scale, high-quality data, it generates consistent audio from text prompts across natural ambience, urban scenes, creatures, human actions, and music-like clips. It is well suited for film and game production, interactive experiences, and data synthesis pipelines.
22
+ - **MOSS‑TTS‑Realtime**: MOSS-TTS-Realtime is a context-aware, multi-turn streaming TTS foundation model designed for real-time voice agents. Unlike conventional TTS that synthesizes replies in isolation, it conditions generation on multi-turn dialogue history—including both textual and acoustic signals from prior user speech—so responses stay coherent, consistent, and natural across turns. With low-latency incremental synthesis and strong voice stability, it enables truly conversational, human-like real-time speech experiences.
23
+
24
+
25
+ ## Released Models
26
+
27
+ | Model | Architecture | Size | Model Card | Hugging Face |
28
+ |---|---|---:|---|---|
29
+ | **MOSS-TTS** | MossTTSDelay | 8B | [moss_tts_model_card.md](https://github.com/OpenMOSS/MOSS-TTS/blob/main/moss_tts_model_card.md) | 🤗 [Huggingface](https://huggingface.co/OpenMOSS-Team/MOSS-TTS) |
30
+ | | MossTTSLocal | 1.7B | [moss_tts_model_card.md](https://github.com/OpenMOSS/MOSS-TTS/blob/main/moss_tts_model_card.md) | 🤗 [Huggingface](https://huggingface.co/OpenMOSS-Team/MOSS-TTS-Local-Transformer) |
31
+ | **MOSS‑TTSD‑V1.0** | MossTTSDelay | 8B | [moss_ttsd_model_card.md](https://github.com/OpenMOSS/MOSS-TTS/blob/main/moss_ttsd_model_card.md) | 🤗 [Huggingface](https://huggingface.co/OpenMOSS-Team/MOSS-TTSD-v1.0) |
32
+ | **MOSS‑VoiceGenerator** | MossTTSDelay | 1.7B | [moss_voice_generator_model_card.md](https://github.com/OpenMOSS/MOSS-TTS/blob/main/moss_voice_generator_model_card.md) | 🤗 [Huggingface](https://huggingface.co/OpenMOSS-Team/MOSS-Voice-Generator) |
33
+ | **MOSS‑SoundEffect** | MossTTSDelay | 8B | [moss_sound_effect_model_card.md](https://github.com/OpenMOSS/MOSS-TTS/blob/main/moss_sound_effect_model_card.md) | 🤗 [Huggingface](https://huggingface.co/OpenMOSS-Team/MOSS-SoundEffect) |
34
+ | **MOSS‑TTS‑Realtime** | MossTTSRealtime | 1.7B | [moss_tts_realtime_model_card.md](https://github.com/OpenMOSS/MOSS-TTS/blob/main/moss_tts_realtime_model_card.md) | 🤗 [Huggingface](https://huggingface.co/OpenMOSS-Team/MOSS-TTS-Realtime) |
35
+
36
+ <br>
37
+
38
+ # MOSS-TTSD
39
+
40
+ **MOSS-TTSD** is a long-form spoken dialogue generation model that enables highly expressive multi-party conversational speech synthesis across multiple languages. It supports continuous long-duration generation, flexible multi-speaker dialogue control, and state-of-the-art zero-shot voice cloning from only short reference audio. MOSS-TTSD is designed for real-world long-form content creation, including podcasts, audiobooks, sports and esports commentary, dubbing, crosstalk, and entertainment scenarios.
41
+
42
+
43
+ ## 1. Overview
44
+
45
+ ### 1.1 TTS Family Positioning
46
+ MOSS-TTSD is the Long-Form Dialogue Specialist in our open-source TTS Family. While our foundational models focus on high-fidelity single-speaker synthesis, MOSS-TTSD extends this capability into the realm of complex, multi-party interactions. It is designed to bridge the gap between distinct audio samples and cohesive, continuous conversation.
47
+
48
+ **Design Goals**
49
+ - **Authentic Interaction**: Capturing the natural rhythm, overlaps, and dynamics of human conversation.
50
+ - **Sustained Coherence**: Maintaining speaker identity and contextual consistency over extended durations (up to 1 hour).
51
+ - **Production Adaptability**: Serving diverse high-end scenarios from rigorous audiobook narration to dynamic sports commentary.
52
+
53
+ ### 1.2 Key Capabilities
54
+ MOSS-TTSD transforms static text into living conversations, offering features specifically optimized for multi-speaker environments:
55
+
56
+ - **Multi-Party Conversational Generation** — Unlike traditional TTS, which is optimized for read speech, MOSS-TTSD masters the rhythm of conversation. It supports 1 to 5 speakers with flexible control, handling natural turn-taking, overlapping speech patterns, and consistent per-speaker personas.
57
+
58
+ - **Extreme Long-Context Modeling** — Moving beyond short-sentence generation, the model is architected for stability over long durations, supporting up to 60 minutes of coherent audio in a single session without losing speaker identity or prosodic quality.
59
+
60
+ - **Diverse Scenario Adaptation** — The model is fine-tuned on high-variability scenarios to handle different speaking styles:
61
+ - Conversational Media: AI Podcasts, Interviews.
62
+ - Dynamic Commentary: High-energy Sports/Esports shouting and analysis.
63
+ - Entertainment: Audiobooks (narrator + characters), Dubbing, and Crosstalk (Xiangsheng).
64
+
65
+ - **Multilingual & Zero-Shot Cloning** — Features state-of-the-art zero-shot voice cloning requiring only short reference audio (3-10s), with robust cross-lingual performance across major languages including Chinese, English, Japanese, and European languages.
66
+
67
+ ### 1.3 Model Architecture
68
+
69
+ MOSS-TTSD is built on top of **Architecture A: Delay Pattern (MossTTSDelay)** from our MOSS-TTS foundation model — a single Transformer backbone with multi-head parallel prediction using delay scheduling for multi-codebook audio tokens.
70
+ <!-- For full architecture details, see **`moss_tts_delay/moss_tts_delay_architecture.md`**. -->
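+
+ As a rough illustration (a minimal sketch, not the model's internal implementation), the delay pattern can be thought of as shifting codebook `k` right by `k` steps, so the backbone predicts all codebooks of a given frame across consecutive time steps; `apply_delay_pattern` and `pad_id` below are hypothetical names used only for this illustration.
+
+ ```python
+ import torch
+
+ def apply_delay_pattern(codes: torch.Tensor, pad_id: int) -> torch.Tensor:
+     # codes: (n_vq, T) multi-codebook audio tokens.
+     # Returns an (n_vq, T + n_vq - 1) grid where codebook k starts k steps later,
+     # with pad_id filling the positions outside each codebook's shifted span.
+     n_vq, num_frames = codes.shape
+     delayed = torch.full((n_vq, num_frames + n_vq - 1), pad_id, dtype=codes.dtype)
+     for k in range(n_vq):
+         delayed[k, k : k + num_frames] = codes[k]
+     return delayed
+ ```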
71
+
72
+ ### 1.4 Released Models
73
+
74
+ | Model | Architecture | NVQ | Parameters |
75
+ |-------|-------------|-----|------------|
76
+ | MOSS-TTSD | Architecture A: Delay Pattern (MossTTSDelay) | 16 | 8B |
77
+
78
+ **Recommended decoding hyperparameters**
79
+
80
+ | Model | audio_temperature | audio_top_p | audio_top_k | audio_repetition_penalty |
81
+ |---|---:|---:|---:|---:|
82
+ | **MOSS-TTSD** | 1.1 | 0.9 | 50 | 1.1 |
83
+
84
+ ## 2. Quick Start
85
+
86
+ MOSS-TTSD uses a **continuation** workflow: provide reference audio for each speaker, their transcripts as a prefix, and the dialogue text to generate. The model continues in each speaker's identity.
87
+
88
+ ```python
89
+ import os
90
+ from pathlib import Path
91
+ import torch
92
+ import torchaudio
93
+ from transformers import AutoModel, AutoProcessor
94
+
95
+ pretrained_model_name_or_path = "OpenMOSS-Team/MOSS-TTSD"
96
+ audio_tokenizer_name_or_path = "OpenMOSS-Team/MOSS-Audio-Tokenizer"
97
+ device = "cuda" if torch.cuda.is_available() else "cpu"
98
+ dtype = torch.bfloat16 if device == "cuda" else torch.float32
99
+
100
+ processor = AutoProcessor.from_pretrained(
101
+     pretrained_model_name_or_path,
102
+     trust_remote_code=True,
103
+     codec_path=audio_tokenizer_name_or_path,
104
+ )
105
+ processor.audio_tokenizer = processor.audio_tokenizer.to(device)
106
+ processor.audio_tokenizer.eval()
107
+
108
+ model = AutoModel.from_pretrained(
109
+     pretrained_model_name_or_path,
110
+     trust_remote_code=True,
111
+     attn_implementation="flash_attention_2",
112
+     torch_dtype=dtype,
113
+ ).to(device)
114
+ model.eval()
115
+
116
+ # --- Inputs ---
117
+
118
+ prompt_audio_speaker1 = "https://speech-demo.oss-cn-shanghai.aliyuncs.com/moss_tts_demo/tts_readme_demo/reference_02_s1.wav"
119
+ prompt_audio_speaker2 = "https://speech-demo.oss-cn-shanghai.aliyuncs.com/moss_tts_demo/tts_readme_demo/reference_02_s2.wav"
120
+ prompt_text_speaker1 = "[S1] In short, we embarked on a mission to make America great again for all Americans."
121
+ prompt_text_speaker2 = "[S2] NVIDIA reinvented computing for the first time after 60 years. In fact, Erwin at IBM knows quite well that the computer has largely been the same since the 60s."
122
+
123
+ text_to_generate = "[S1] Listen, let's talk business. China. I'm hearing things. People are saying they're catching up. Fast. What's the real scoop? Their AI—is it a threat? [S2] Well, the pace of innovation there is extraordinary, honestly. They have the researchers, and they have the drive. [S1] Extraordinary? I don't like that. I want us to be extraordinary. Are they winning? [S2] I wouldn't say winning, but their progress is very promising. They are building massive clusters. They're very determined. [S1] Promising. There it is. I hate that word. When China is promising, it means we're losing. It's a disaster, Jensen. A total disaster. "
124
+
125
+ # --- Load & resample audio ---
126
+
127
+ target_sr = int(processor.model_config.sampling_rate)
128
+ wav1, sr1 = torchaudio.load(prompt_audio_speaker1)
129
+ wav2, sr2 = torchaudio.load(prompt_audio_speaker2)
130
+
131
+ if wav1.shape[0] > 1:
132
+     wav1 = wav1.mean(dim=0, keepdim=True)
133
+ if wav2.shape[0] > 1:
134
+     wav2 = wav2.mean(dim=0, keepdim=True)
135
+ if sr1 != target_sr:
136
+     wav1 = torchaudio.functional.resample(wav1, sr1, target_sr)
137
+ if sr2 != target_sr:
138
+     wav2 = torchaudio.functional.resample(wav2, sr2, target_sr)
139
+
140
+ # --- Build conversation ---
141
+
142
+ reference_audio_codes = processor.encode_audios_from_wav([wav1, wav2], sampling_rate=target_sr)
143
+ concat_prompt_wav = torch.cat([wav1, wav2], dim=-1)
144
+ prompt_audio = processor.encode_audios_from_wav([concat_prompt_wav], sampling_rate=target_sr)[0]
145
+
146
+ full_text = f"{prompt_text_speaker1} {prompt_text_speaker2} {text_to_generate}"
147
+
148
+ conversations = [
149
+     [
150
+         processor.build_user_message(
151
+             text=full_text,
152
+             reference=reference_audio_codes,
153
+         ),
154
+         processor.build_assistant_message(
155
+             audio_codes_list=[prompt_audio]
156
+         ),
157
+     ],
158
+ ]
159
+
160
+ # --- Inference ---
161
+
162
+ batch_size = 1
163
+
164
+ save_dir = Path("output")
165
+ save_dir.mkdir(exist_ok=True, parents=True)
166
+ sample_idx = 0
167
+ with torch.no_grad():
168
+     for start in range(0, len(conversations), batch_size):
169
+         batch_conversations = conversations[start : start + batch_size]
170
+         batch = processor(batch_conversations, mode="continuation")
171
+         input_ids = batch["input_ids"].to(device)
172
+         attention_mask = batch["attention_mask"].to(device)
173
+
174
+         outputs = model.generate(
175
+             input_ids=input_ids,
176
+             attention_mask=attention_mask,
177
+             max_new_tokens=2000,
178
+         )
179
+
180
+         for message in processor.decode(outputs):
181
+             for seg_idx, audio in enumerate(message.audio_codes_list):
182
+                 torchaudio.save(save_dir / f"{sample_idx}_{seg_idx}.wav", audio.unsqueeze(0), processor.model_config.sampling_rate)
183
+             sample_idx += 1
184
+
185
+ ```
186
+
187
+ ### Input Types
188
+
189
+ **UserMessage**
190
+
191
+ | Field | Type | Required | Description |
192
+ |---|---|---:|---|
193
+ | `text` | `str` | Yes | Full dialogue text including speaker tags (`[S1]`, `[S2]`, ...) and prompt transcripts. |
194
+ | `reference` | `List` | Yes | Per-speaker reference audio codes from `processor.encode_audios_from_wav()`. |
195
+
196
+ **AssistantMessage**
197
+
198
+ | Field | Type | Required | Description |
199
+ |---|---|---:|---|
200
+ | `audio_codes_list` | `List` | Yes | Concatenated prompt audio codes for all speakers. |
201
+
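+ As a sketch of how these fields compose for more speakers (an assumption extrapolated from the Quick Start, where reference index 0 pairs with `[S1]` and index 1 with `[S2]`; `wav3`, `prompt_texts`, and `text_to_generate` are placeholders):
+
+ ```python
+ # Hypothetical three-speaker setup; wav1/wav2/wav3 are mono prompts already resampled to target_sr.
+ wavs = [wav1, wav2, wav3]
+ prompt_texts = ["[S1] <transcript of wav1>", "[S2] <transcript of wav2>", "[S3] <transcript of wav3>"]
+
+ reference_audio_codes = processor.encode_audios_from_wav(wavs, sampling_rate=target_sr)
+ prompt_audio = processor.encode_audios_from_wav([torch.cat(wavs, dim=-1)], sampling_rate=target_sr)[0]
+
+ full_text = " ".join(prompt_texts) + " " + text_to_generate  # prompt transcripts first, then the dialogue
+
+ conversation = [
+     processor.build_user_message(text=full_text, reference=reference_audio_codes),
+     processor.build_assistant_message(audio_codes_list=[prompt_audio]),
+ ]
+ ```
+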
202
+ ### Generation Hyperparameters
203
+
204
+ | Parameter | Type | Default | Description |
205
+ |---|---|---:|---|
206
+ | `max_new_tokens` | `int` | — | Controls total generated audio tokens. **1s ≈ 12.5 tokens**. |
207
+ | `audio_temperature` | `float` | 1.1 | Higher values increase variation; lower values stabilize prosody. |
208
+ | `audio_top_p` | `float` | 0.9 | Nucleus sampling cutoff. |
209
+ | `audio_top_k` | `int` | 50 | Top-K sampling. |
210
+ | `audio_repetition_penalty` | `float` | 1.1 | >1.0 discourages repeating patterns. |
211
+
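+ As a minimal sketch (assuming the Quick Start variables above; `target_seconds` is a placeholder), the recommended MOSS-TTSD values and a duration-derived token budget can be passed directly to `model.generate`:
+
+ ```python
+ import math
+
+ target_seconds = 120                                 # desired length of the generated dialogue
+ max_new_tokens = math.ceil(target_seconds * 12.5)    # 1 s of audio ≈ 12.5 tokens
+
+ outputs = model.generate(
+     input_ids=input_ids,
+     attention_mask=attention_mask,
+     max_new_tokens=max_new_tokens,
+     audio_temperature=1.1,                           # recommended MOSS-TTSD decoding values (Section 1.4)
+     audio_top_p=0.9,
+     audio_top_k=50,
+     audio_repetition_penalty=1.1,
+ )
+ ```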
212
+
213
+ ## 3. Evaluation
214
+ ### Objective Evaluation (TTSD-eval)
215
+
216
+
217
+
218
+ We introduce a robust evaluation framework leveraging **MMS-FA** for alignment and **wespeaker** for embedding extraction to ensure precise speaker attribution.
219
+
220
+
221
+
222
+ - **Method**: Forced-alignment based segmentation + Similarity-based speaker verification.
223
+
224
+ - **Metrics** (a minimal scoring sketch follows this list):
225
+ - **Speaker Attribution Accuracy (ACC)**
226
+ - **Speaker Similarity (SIM)**
227
+ - **Word Error Rate (WER)** computed using **Whisper-large-v3**.
228
+
229
+ - **Dataset**: 100 multi-turn dialogues (CN/EN) spanning 30s–720s. Covers diverse scenarios including podcasts, TV dubbing, and crosstalk. Code and data coming soon.
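+
+ As a minimal scoring sketch (not the released evaluation code; `attribution_metrics` is a hypothetical helper, and the segment embeddings are assumed to come from the alignment and speaker-embedding steps above), ACC and SIM can be computed from cosine similarities between segment and reference embeddings:
+
+ ```python
+ import torch
+ import torch.nn.functional as F
+
+ def attribution_metrics(seg_embs: torch.Tensor, ref_embs: torch.Tensor, labels: torch.Tensor):
+     # seg_embs: (N, D) embeddings of force-aligned dialogue segments
+     # ref_embs: (S, D) one embedding per target speaker's reference audio
+     # labels:   (N,) long tensor with the intended speaker index of each segment
+     sims = F.cosine_similarity(seg_embs.unsqueeze(1), ref_embs.unsqueeze(0), dim=-1)  # (N, S)
+     predicted = sims.argmax(dim=-1)
+     acc = (predicted == labels).float().mean()                    # Speaker Attribution Accuracy (ACC)
+     sim = sims.gather(1, labels.unsqueeze(1)).squeeze(1).mean()   # Speaker Similarity (SIM)
+     return acc.item(), sim.item()
+ ```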
230
+ <br>
231
+
232
+ | Model | ZH - SIM | ZH - ACC | ZH - WER | EN - SIM | EN - ACC | EN - WER |
233
+ | :--- | :---: | :---: | :---: | :---: | :---: | :---: |
234
+ | **Comparison with Open-Source Models** | | | | | | |
235
+ | MOSS-TTSD | **0.7949** | **0.9587** | **0.0485** | **0.7326** | **0.9626** | 0.0988 |
236
+ | MOSS-TTSD v0.7 | 0.7423 | 0.9391 | 0.0517 | 0.6743 | 0.9266 | 0.1612 |
237
+ | VibeVoice 7B | 0.7590 | 0.9222 | 0.0570 | 0.7140 | 0.9554 | **0.0946** |
238
+ | VibeVoice 1.5B | 0.7415 | 0.8798 | 0.0818 | 0.6961 | 0.9353 | 0.1133 |
239
+ | FireRedTTS2 | 0.7383 | 0.9022 | 0.0768 | - | - | - |
240
+ | Higgs Audio V2 | - | - | - | 0.6860 | 0.9025 | 0.2131 |
241
+ | **Comparison with Proprietary Models** | | | | | | |
242
+ | Eleven V3 | 0.6970 | 0.9653 | **0.0363** | 0.6730 | 0.9498 | **0.0824** |
243
+ | MOSS-TTSD (elevenlabs_voice) | **0.8165** | **0.9736** | 0.0391 | **0.7304** | **0.9565** | 0.1005 |
244
+ | | | | | | | |
245
+ | gemini-2.5-pro-preview-tts | - | - | - | 0.6786 | 0.9537 | **0.0859** |
246
+ | gemini-2.5-flash-preview-tts | - | - | - | 0.7194 | 0.9511 | 0.0871 |
247
+ | MOSS-TTSD (gemini_voice) | - | - | - | **0.7893** | **0.9655** | 0.0984 |
248
+ | | | | | | | |
249
+ | Doubao_Podcast | 0.8034 | 0.9606 | **0.0472** | - | - | - |
250
+ | MOSS-TTSD (doubao_voice) | **0.8226** | **0.9630** | 0.0571 | - | - | - |
251
+
252
+ ### Subjective Evaluation
253
+ For open-source models, annotators are asked to score each sample pair in terms of speaker attribution accuracy, voice similarity, prosody, and overall quality. Following the methodology of the LMSYS Chatbot Arena, we compute Elo ratings and confidence intervals for each dimension.
254
+ ![Elo ratings and confidence intervals vs. open-source models](assets/VS_Open-Source_Models.png)
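+
+ For reference, a generic per-comparison Elo update looks like the following (a sketch of the standard formula, not the exact Arena-style rating fit used here):
+
+ ```python
+ def elo_update(rating_a: float, rating_b: float, score_a: float, k: float = 32.0):
+     # score_a: 1.0 if A is preferred, 0.5 for a tie, 0.0 if B is preferred.
+     expected_a = 1.0 / (1.0 + 10 ** ((rating_b - rating_a) / 400.0))
+     new_a = rating_a + k * (score_a - expected_a)
+     new_b = rating_b + k * ((1.0 - score_a) - (1.0 - expected_a))
+     return new_a, new_b
+ ```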
255
+
256
+ For closed-source models, annotators are asked only to choose the overall preferred sample in each pair, and we compute win rates accordingly.
257
+ ![Win rates vs. proprietary models (1)](assets/VS_Proprietary_Models1.png)
258
+ ![Win rates vs. proprietary models (2)](assets/VS_Proprietary_Models2.png)
modeling_moss_tts.py CHANGED
@@ -395,7 +395,7 @@ class MossTTSDelayModel(MossTTSDelayPreTrainedModel):
395
  input_ids: torch.LongTensor,
396
  attention_mask: Optional[torch.Tensor] = None,
397
  max_new_tokens: Optional[int] = None,
398
- text_temperature: float = 1.2,
399
  text_top_p: float = 0.9,
400
  text_top_k: int = 50,
401
  audio_temperature: Optional[float] = None,
@@ -460,14 +460,14 @@ class MossTTSDelayModel(MossTTSDelayPreTrainedModel):
460
  generation_ids = input_ids[:]
461
  is_stopping = torch.zeros(batch_size, dtype=torch.bool, device=device)
462
 
463
- # 三个阶段: 1. audio; 2. audio not delay; 3. audio delay
464
- audio_lengths = torch.zeros(batch_size, dtype=torch.int64, device=device) # 0 的时候表示阶段1;
465
  torch_int64_max = torch.iinfo(torch.int64).max
466
- delayed_lengths = torch.full((batch_size,), torch_int64_max, dtype=torch.int64, device=device) # 最大值的时候表示阶段2;
467
 
468
- # 考虑 continuation audio_start 已经在 input_ids 中的情况;
469
- # NOTE 注意我们目前不考虑任何输入已经开始 delay 的情况;
470
- # 需要同时考虑 continuation 和直接生成的情况;
471
  is_continuation = (input_ids[:, -1, 0] == self.config.audio_start_token_id) | (input_ids[:, -1, 0] == self.config.audio_assistant_gen_slot_token_id)
472
  audio_start_indices = find_last_equal_C(input_ids[..., 0], self.config.audio_start_token_id)
473
  audio_start_mask = is_continuation & (audio_start_indices != -1)
@@ -480,7 +480,7 @@ class MossTTSDelayModel(MossTTSDelayPreTrainedModel):
480
  pre_exclude_mask1[[self.config.audio_assistant_gen_slot_token_id, self.config.audio_assistant_delay_slot_token_id]] = False
481
 
482
 
483
- # 注意 time_step 未必表示对于实际对话时,当前输出token的位置,因为有续写的情况;
484
  for time_step in tqdm(range(max_new_tokens), desc=f"Generating bs{batch_size} ..."):
485
  outputs = self(
486
  input_ids=current_input_ids,
@@ -492,9 +492,10 @@ class MossTTSDelayModel(MossTTSDelayPreTrainedModel):
492
 
493
  next_token_logits = [logit[:, -1, :] / text_temperature if logit_idx == 0 else logit[:, -1, :] / audio_temperature for logit_idx, logit in enumerate(outputs.logits)] # List, len=n_vq+1, [batch_size, 1, vocab_size];
494
  next_token_logits[0] = next_token_logits[0].clone()
495
- # 1. 先处理 text token;
496
  next_text_token = torch.full((batch_size,), self.config.pad_token_id, device=device)
497
- # 第二个 audio_assistant_delay_slot_token_id audio_end 是不需要采样的,audio_start, 每一个 audio_assistant_gen_slot_token_ids 和第一个 audio_assistant_delay_slot_token_id 是需要采样的;
 
498
  next_text_token[~is_stopping & (delayed_lengths < n_vq)] = self.config.audio_assistant_delay_slot_token_id
499
  is_audio_eos = ~is_stopping & (delayed_lengths == n_vq)
500
  next_text_token[is_audio_eos] = self.config.audio_end_token_id
@@ -507,7 +508,7 @@ class MossTTSDelayModel(MossTTSDelayPreTrainedModel):
507
  if time_step <= n_vq:
508
  next_token_logits[0][..., self.config.im_end_token_id] = float('-inf')
509
 
510
- # 文本层不使用重复惩罚;
511
  next_text_token[sampling_text_mask] = sample_token(
512
  logits=next_token_logits[0][sampling_text_mask],
513
  top_p=text_top_p,
@@ -515,15 +516,15 @@ class MossTTSDelayModel(MossTTSDelayPreTrainedModel):
515
  do_sample=text_do_sample
516
  )
517
  is_audio[next_text_token == self.config.audio_start_token_id] = True
518
- # 只存在一种停止逻辑,即 next_text_token = <|im_end|>;
519
  is_stopping[next_text_token == self.config.im_end_token_id] = True
520
 
521
- # 2. 再处理 audio tokens;
522
- # audio_start audio_end 之外的内容直接pad,默认是 pad,我们只需要填充有值的部分即可;
523
  next_audio_tokens = torch.full((batch_size, n_vq), self.config.audio_pad_code, device=device)
524
 
525
- # 需要考虑的是与 audio_start 的距离;
526
- # 先查看是否是pad的情况; true 表示有值;
527
  pre_audio_mask = audio_lengths.unsqueeze(1) > torch.arange(n_vq, dtype=int, device=device).expand(batch_size, n_vq)
528
  post_audio_mask = torch.arange(n_vq, dtype=int, device=device).expand(batch_size, n_vq) > delayed_lengths.unsqueeze(1) - 1
529
  post_audio_mask[delayed_lengths == torch_int64_max] = True
@@ -531,18 +532,32 @@ class MossTTSDelayModel(MossTTSDelayPreTrainedModel):
531
  next_audio_tokens[~sampling_audio_mask] = self.config.audio_pad_code
532
 
533
  if sampling_audio_mask.sum() > 0:
534
- audio_logits = torch.stack(next_token_logits[1:], dim=1)[sampling_audio_mask] # torch.stack -> [batch_size, n_vq - 1, vocab_size]
 
535
  audio_logits[..., self.config.audio_pad_code] = float('-inf')
536
- next_audio_tokens[sampling_audio_mask] = sample_token(
537
  logits=audio_logits,
538
- prev_tokens=generation_ids[:, :, 1:],
539
  repetition_penalty=audio_repetition_penalty,
540
  top_p=audio_top_p,
541
  top_k=audio_top_k,
542
  do_sample=audio_do_sample
543
  )
 
544
 
545
- # 这里显示的是下一个时间步时可以直接使用的 audio_lengths delayed_lengths 的状态;
546
  # audio_lengths[(next_text_token == self.audio_start_token_id) & (audio_lengths > 0)] += 1
547
  # audio_lengths[(next_text_token == self.audio_start_token_id) | (next_text_token == self.audio_assistant_gen_slot_token_id)] += 1
548
  audio_lengths[(next_text_token == self.config.audio_start_token_id) | (next_text_token == self.config.audio_assistant_gen_slot_token_id) | (next_text_token == self.config.audio_assistant_delay_slot_token_id)] += 1
 
395
  input_ids: torch.LongTensor,
396
  attention_mask: Optional[torch.Tensor] = None,
397
  max_new_tokens: Optional[int] = None,
398
+ text_temperature: float = 1.1,
399
  text_top_p: float = 0.9,
400
  text_top_k: int = 50,
401
  audio_temperature: Optional[float] = None,
 
460
  generation_ids = input_ids[:]
461
  is_stopping = torch.zeros(batch_size, dtype=torch.bool, device=device)
462
 
463
+ # Three phases: 1) non-audio, 2) audio generation before delay, 3) delayed audio.
464
+ audio_lengths = torch.zeros(batch_size, dtype=torch.int64, device=device) # 0 means phase 1.
465
  torch_int64_max = torch.iinfo(torch.int64).max
466
+ delayed_lengths = torch.full((batch_size,), torch_int64_max, dtype=torch.int64, device=device) # int64 max means phase 2.
467
 
468
+ # Handle continuation where audio_start is already present in input_ids.
469
+ # NOTE: delayed-audio continuation is currently not handled.
470
+ # Support both continuation and fresh generation.
471
  is_continuation = (input_ids[:, -1, 0] == self.config.audio_start_token_id) | (input_ids[:, -1, 0] == self.config.audio_assistant_gen_slot_token_id)
472
  audio_start_indices = find_last_equal_C(input_ids[..., 0], self.config.audio_start_token_id)
473
  audio_start_mask = is_continuation & (audio_start_indices != -1)
 
480
  pre_exclude_mask1[[self.config.audio_assistant_gen_slot_token_id, self.config.audio_assistant_delay_slot_token_id]] = False
481
 
482
 
483
+ # time_step is a generation step, not the absolute dialogue position under continuation.
484
  for time_step in tqdm(range(max_new_tokens), desc=f"Generating bs{batch_size} ..."):
485
  outputs = self(
486
  input_ids=current_input_ids,
 
492
 
493
  next_token_logits = [logit[:, -1, :] / text_temperature if logit_idx == 0 else logit[:, -1, :] / audio_temperature for logit_idx, logit in enumerate(outputs.logits)] # List, len=n_vq+1, [batch_size, 1, vocab_size];
494
  next_token_logits[0] = next_token_logits[0].clone()
495
+ # 1) Process text token first.
496
  next_text_token = torch.full((batch_size,), self.config.pad_token_id, device=device)
497
+ # The second delay-slot token and audio_end are fixed; audio_start, each gen-slot token,
498
+ # and the first delay-slot token are sampled.
499
  next_text_token[~is_stopping & (delayed_lengths < n_vq)] = self.config.audio_assistant_delay_slot_token_id
500
  is_audio_eos = ~is_stopping & (delayed_lengths == n_vq)
501
  next_text_token[is_audio_eos] = self.config.audio_end_token_id
 
508
  if time_step <= n_vq:
509
  next_token_logits[0][..., self.config.im_end_token_id] = float('-inf')
510
 
511
+ # No repetition penalty on the text channel.
512
  next_text_token[sampling_text_mask] = sample_token(
513
  logits=next_token_logits[0][sampling_text_mask],
514
  top_p=text_top_p,
 
516
  do_sample=text_do_sample
517
  )
518
  is_audio[next_text_token == self.config.audio_start_token_id] = True
519
+ # Single stop condition: next_text_token == <|im_end|>.
520
  is_stopping[next_text_token == self.config.im_end_token_id] = True
521
 
522
+ # 2) Then process audio tokens.
523
+ # Outside [audio_start, audio_end], keep pad; only fill valid positions.
524
  next_audio_tokens = torch.full((batch_size, n_vq), self.config.audio_pad_code, device=device)
525
 
526
+ # Build masks based on distance from audio_start.
527
+ # True means this position should contain a real audio token.
528
  pre_audio_mask = audio_lengths.unsqueeze(1) > torch.arange(n_vq, dtype=int, device=device).expand(batch_size, n_vq)
529
  post_audio_mask = torch.arange(n_vq, dtype=int, device=device).expand(batch_size, n_vq) > delayed_lengths.unsqueeze(1) - 1
530
  post_audio_mask[delayed_lengths == torch_int64_max] = True
 
532
  next_audio_tokens[~sampling_audio_mask] = self.config.audio_pad_code
533
 
534
  if sampling_audio_mask.sum() > 0:
535
+ # audio_logits = torch.stack(next_token_logits[1:], dim=1)[sampling_audio_mask] # torch.stack -> [batch_size, n_vq - 1, vocab_size]
536
+ audio_logits = torch.stack(next_token_logits[2:], dim=1)[sampling_audio_mask[:, 1:]]
537
  audio_logits[..., self.config.audio_pad_code] = float('-inf')
538
+
539
+ audio_ch0_logits = next_token_logits[1][sampling_audio_mask[:, 0]]
540
+ audio_ch0_logits[..., 1024] = float('-inf')
541
+ next_audio_tokens[:, 0][sampling_audio_mask[:, 0]] = sample_token(
542
+ logits=audio_ch0_logits,
543
+ prev_tokens=generation_ids[:, :, 1],
544
+ repetition_penalty=audio_repetition_penalty,
545
+ top_p=audio_top_p,
546
+ top_k=audio_top_k,
547
+ do_sample=audio_do_sample
548
+ )
549
+ # print(f"{next_audio_tokens[:, 0][sampling_audio_mask[:, 0]] = }")
550
+ next_audio_tokens[:, 1:][sampling_audio_mask[:, 1:]] = sample_token(
551
  logits=audio_logits,
552
+ prev_tokens=generation_ids[:, :, 2:],
553
  repetition_penalty=audio_repetition_penalty,
554
  top_p=audio_top_p,
555
  top_k=audio_top_k,
556
  do_sample=audio_do_sample
557
  )
558
+ # print(f"{next_audio_tokens[:, 1:][sampling_audio_mask[:, 1:]] = }")
559
 
560
+ # Update audio_lengths and delayed_lengths for direct use in the next step.
561
  # audio_lengths[(next_text_token == self.audio_start_token_id) & (audio_lengths > 0)] += 1
562
  # audio_lengths[(next_text_token == self.audio_start_token_id) | (next_text_token == self.audio_assistant_gen_slot_token_id)] += 1
563
  audio_lengths[(next_text_token == self.config.audio_start_token_id) | (next_text_token == self.config.audio_assistant_gen_slot_token_id) | (next_text_token == self.config.audio_assistant_delay_slot_token_id)] += 1
processing_moss_tts.py CHANGED
@@ -555,7 +555,7 @@ class MossTTSDelayProcessor(ProcessorMixin):
555
  truncation: bool,
556
  ) -> torch.Tensor:
557
  """
558
- 此时的 content 已经是带上了对话格式
559
  """
560
  if role == "user":
561
  audio_gen_slot_token = audio_delay_slot_token = self.audio_user_slot_token
@@ -740,8 +740,8 @@ class MossTTSDelayProcessor(ProcessorMixin):
740
 
741
  def decode(self, output: List[Tuple[int, torch.Tensor]]):
742
  """
743
- 1. 这里不管怎样,都需要一个完整的 assistant generation ids;
744
- 2. 支持从任意位置进行截断;
745
  """
746
 
747
  genearted_messages = []
@@ -927,58 +927,34 @@ class MossTTSDelayProcessor(ProcessorMixin):
927
  for codes in audio_tokens_list
928
  ]
929
 
930
- # Align with legacy behavior: decode each sample with chunk_duration=8.0.
931
- # Streaming chunk decode currently supports batch_size=1 in MossAudioTokenizer.
932
- if hasattr(audio_tokenizer, "decode"):
933
- wav_list: List[torch.Tensor] = []
934
- for codes in codes_list:
935
- try:
936
- dec = audio_tokenizer.decode(
937
- codes,
938
- return_dict=True,
939
- chunk_duration=8.0,
940
- )
941
- except TypeError:
942
- # Compatibility fallback for tokenizers without chunk_duration arg.
943
- dec = audio_tokenizer.decode(
944
- codes,
945
- return_dict=True,
946
- )
947
-
948
- audio = dec.audio
949
- audio_lengths = dec.audio_lengths
950
- if audio is None:
951
- raise RuntimeError("audio_tokenizer.decode() returned empty audio.")
952
 
953
- if audio_lengths is None:
954
- cur_len = int(audio.shape[-1])
955
- else:
956
- cur_len = int(audio_lengths[0].item())
957
 
958
- if audio.ndim == 3:
959
- wav = audio[0, 0, :cur_len]
960
- elif audio.ndim == 2:
961
- wav = audio[0, :cur_len]
962
- else:
963
- raise RuntimeError(
964
- f"Unexpected audio shape from decode: {tuple(audio.shape)}"
965
- )
966
- wav_list.append(wav.contiguous().to(torch.float32).cpu())
967
- return wav_list
968
-
969
- if hasattr(audio_tokenizer, "batch_decode"):
970
- dec = audio_tokenizer.batch_decode(codes_list)
971
- audio = dec.audio
972
- audio_lengths = dec.audio_lengths
973
- if audio is None or audio_lengths is None:
974
- raise RuntimeError(
975
- "audio_tokenizer.batch_decode() returned empty outputs (audio/audio_lengths)."
976
- )
977
- wav_list = []
978
- for i in range(int(audio.shape[0])):
979
- length_i = int(audio_lengths[i].item())
980
- wav = audio[i, 0, :length_i].contiguous().to(torch.float32).cpu()
981
- wav_list.append(wav)
982
- return wav_list
983
-
984
- raise RuntimeError("audio_tokenizer has neither decode() nor batch_decode().")
 
555
  truncation: bool,
556
  ) -> torch.Tensor:
557
  """
558
+ content is already formatted with the conversation template.
559
  """
560
  if role == "user":
561
  audio_gen_slot_token = audio_delay_slot_token = self.audio_user_slot_token
 
740
 
741
  def decode(self, output: List[Tuple[int, torch.Tensor]]):
742
  """
743
+ 1. Always require complete assistant generation ids.
744
+ 2. Support truncation from arbitrary positions.
745
  """
746
 
747
  genearted_messages = []
 
927
  for codes in audio_tokens_list
928
  ]
929
 
930
+ # Fallback: pad to (NQ, B, T) + mask, then decode.
931
+ nq = int(codes_list[0].shape[0])
932
+ max_t = max(int(c.shape[1]) for c in codes_list)
933
+ audio_codes = torch.zeros(
934
+ nq, len(codes_list), max_t, device=device, dtype=torch.long
935
+ )
936
+ padding_mask = torch.zeros(
937
+ len(codes_list), max_t, device=device, dtype=torch.bool
938
+ )
939
+ for i, c in enumerate(codes_list):
940
+ t = int(c.shape[1])
941
+ audio_codes[:, i, :t] = c
942
+ padding_mask[i, :t] = True
943
+ dec = audio_tokenizer.decode(
944
+ audio_codes, padding_mask=padding_mask, return_dict=True, chunk_duration=8
945
+ )
946
+ audio = dec.audio
947
+ audio_lengths = dec.audio_lengths
 
 
 
 
948
 
949
+ if audio is None or audio_lengths is None:
950
+ raise RuntimeError(
951
+ "audio_tokenizer.decode() returned empty outputs (audio/audio_lengths)."
952
+ )
953
 
954
+ # Return historical contract: list of 1D waveforms (T,)
955
+ wav_list: List[torch.Tensor] = []
956
+ for i in range(int(audio.shape[0])):
957
+ length_i = int(audio_lengths[i].item())
958
+ wav = audio[i, 0, :length_i].contiguous().to(torch.float32).cpu()
959
+ wav_list.append(wav)
960
+ return wav_list