# coding=utf-8
# Copyright 2026 OpenMOSS and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, List, Optional, Tuple, Union
from dataclasses import dataclass
from pathlib import Path
import re
import torchaudio
import torch
from transformers import PreTrainedTokenizerBase, BatchFeature, ProcessorMixin, logging, AutoConfig, AutoModel, AutoTokenizer
from .configuration_moss_tts import MossTTSDelayConfig
logger = logging.get_logger(__name__)
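# Placeholder inserted into message text wherever an audio segment should be spliced in.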
AUDIO_PLACEHOLDER = "<|audio|>"
@dataclass
class Message:
pass
@dataclass
class UserMessage(Message):
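    """A user turn. The fields below are rendered into the <user_inst> template, and
    `reference` entries (audio file paths or [T, n_vq] code tensors) become <|audio|> placeholders.
    """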
text: Optional[str] = None
reference: Optional[List[Optional[Union[str, torch.Tensor]]]] = None
instruction: Optional[str] = None
tokens: Optional[int] = None
quality: Optional[str] = None
sound_event: Optional[str] = None
ambient_sound: Optional[str] = None
language: Optional[str] = None
def __post_init__(self):
template = """<user_inst>
- Reference(s):
{reference}
- Instruction:
{instruction}
- Tokens:
{tokens}
- Quality:
{quality}
- Sound Event:
{sound_event}
- Ambient Sound:
{ambient_sound}
- Language:
{language}
- Text:
{text}
</user_inst>"""
audio_codes_list = []
if self.reference is None:
reference = "None"
elif isinstance(self.reference, List):
reference = []
for speaker_idx, speaker_reference in enumerate(self.reference):
if speaker_reference is not None:
reference.append(f"[S{speaker_idx}]:\n{AUDIO_PLACEHOLDER}")
reference = "\n".join(reference)
audio_codes_list = [speaker_reference for speaker_reference in self.reference if speaker_reference is not None]
else:
raise TypeError("`reference` should be exactly a list when it is not None.")
content = (
template
.replace("{reference}", str(reference))
.replace("{instruction}", str(self.instruction))
.replace("{tokens}", str(self.tokens))
.replace("{quality}", str(self.quality))
.replace("{sound_event}", str(self.sound_event))
.replace("{ambient_sound}", str(self.ambient_sound))
.replace("{language}", str(self.language))
.replace("{text}", str(self.text))
)
self._content = content
self._audio_codes_list = audio_codes_list
def to_dict(self):
return {
"role": "user",
"content": self._content,
"audio_codes_list": self._audio_codes_list
}
@dataclass
class AssistantMessage(Message):
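    """An assistant turn: `content` holds the text with <|audio|> placeholders and
    `audio_codes_list` holds the audio (paths or code tensors) spliced into them.
    """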
audio_codes_list: List[Union[str, torch.Tensor]]
content: str = AUDIO_PLACEHOLDER
def to_dict(self):
return {
"role": "assistant",
"content": self.content,
"audio_codes_list": self.audio_codes_list
}
USER_MESSAGE_FIELDS = (
"text",
"reference",
"instruction",
"tokens",
"quality",
"sound_event",
"ambient_sound",
"language",
)
class MossTTSDelayProcessor(ProcessorMixin):
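    """Processor for MOSS-TTS models that use the delay interleaving pattern.

    It pairs a text tokenizer with a neural audio codec (the "audio tokenizer"):
    `__call__` renders user/assistant messages into a unified code matrix of shape
    [seq_len, 1 + n_vq] (text channel followed by delay-shifted audio channels),
    and `decode` maps generated ids back to text with <|audio|> placeholders plus
    decoded waveforms.

    Minimal usage sketch (the checkpoint path and wav filename below are
    placeholders, not guaranteed names):

        processor = MossTTSDelayProcessor.from_pretrained("path/to/moss-tts-checkpoint")
        message = processor.build_user_message(
            text="Hello there!",
            reference=["reference_speaker.wav"],  # optional voice prompt
        )
        batch = processor([message], mode="generation")
        # batch["input_ids"]: [batch_size, seq_len, 1 + n_vq]
    """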
tokenizer_class = "AutoTokenizer"
audio_tokenizer_class = "AutoModel"
def __init__(
self,
tokenizer: PreTrainedTokenizerBase,
audio_tokenizer: AutoModel = None,
model_config: Optional[MossTTSDelayConfig] = None,
**kwargs
):
super().__init__(
tokenizer=tokenizer,
audio_tokenizer=audio_tokenizer,
**kwargs
)
if model_config is None:
model_config = MossTTSDelayConfig()
self.model_config = model_config
self.imstart_token_id = tokenizer.convert_tokens_to_ids("<|im_start|>")
self.imend_token_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
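        # Hard-coded id assumed to correspond to the "\n" token in the base tokenizer vocabulary.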
self.newline_token_id = 198
self.audio_user_slot_token = tokenizer.convert_ids_to_tokens(self.model_config.audio_user_slot_token_id)
self.audio_assistant_gen_slot_token = tokenizer.convert_ids_to_tokens(self.model_config.audio_assistant_gen_slot_token_id)
self.audio_assistant_delay_slot_token = tokenizer.convert_ids_to_tokens(self.model_config.audio_assistant_delay_slot_token_id)
self.audio_start_token = tokenizer.convert_ids_to_tokens(self.model_config.audio_start_token_id)
self.audio_end_token = tokenizer.convert_ids_to_tokens(self.model_config.audio_end_token_id)
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, trust_remote_code=True, **kwargs):
        kwargs.pop("_from_auto", None)
pretrained_model_name_or_path = Path(pretrained_model_name_or_path)
model_config = AutoConfig.from_pretrained(pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs)
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs)
audio_tokenizer_name_or_path = kwargs.pop("codec_path", "OpenMOSS-Team/MOSS-Audio-Tokenizer")
        if not isinstance(audio_tokenizer_name_or_path, str):
            raise TypeError(f"Unsupported audio tokenizer path type: {type(audio_tokenizer_name_or_path)}")
audio_tokenizer = AutoModel.from_pretrained(audio_tokenizer_name_or_path, trust_remote_code=trust_remote_code, **kwargs)
return cls(
tokenizer=tokenizer,
audio_tokenizer=audio_tokenizer,
model_config=model_config,
**kwargs
)
def __call__(
self,
conversations: Union[Message, Dict, List[Message], List[Dict], List[List[Message]], List[List[Dict]]],
mode: str = "generation",
apply_chat_template: bool = True,
n_vq: Optional[int] = None
) -> BatchFeature:
"""
mode 只会在将 Message 转换为 to_dict 时起作用;
"""
        if mode not in {"generation", "continuation"}:
            raise ValueError(f"`mode` must be 'generation' or 'continuation', got {mode!r}.")
if isinstance(conversations, (Message, Dict)):
conversations = [conversations]
truncation = False
if mode == "continuation":
truncation = True
input_ids_list = []
for conversation in conversations:
if isinstance(conversation, (Message, Dict)):
conversation = [conversation]
if (mode == "generation") ^ (len(conversation) % 2 != 0):
raise ValueError
if (mode == "generation") ^ (conversation[-1]['role'] == "user"):
raise ValueError
unified_codes = []
for message_idx, message in enumerate(conversation):
message = self._normalize_message(message)
if apply_chat_template:
add_generation_prompt = mode == "generation" and message_idx == len(conversation) - 1
try:
content = self.tokenizer.apply_chat_template(
[{"role": message["role"], "content": message["content"]}],
add_generation_prompt=add_generation_prompt,
tokenize=False,
)
except TypeError:
try:
content = self.tokenizer.apply_chat_template(
[{"role": message["role"], "content": message["content"]}],
add_generation_prompt=add_generation_prompt,
)
except Exception:
logger.warning("apply_chat_template failed; fallback to raw content.")
content = message["content"]
else:
content = message['content']
audio_codes_list = []
for audio_codes in message["audio_codes_list"]:
if isinstance(audio_codes, torch.Tensor):
if n_vq is not None and audio_codes.shape[1] != n_vq:
raise RuntimeError("audio_codes's n_vq is not equal to the parameter `n_vq`. Your can set the parameter `n_vq` as None if you have already tokenzied the wavs.")
else:
audio_codes = self.encode_audios_from_path(audio_codes, n_vq)[0]
audio_codes_list.append(audio_codes)
unified_codes.append(self._get_unified_codes(message['role'], content, audio_codes_list, truncation))
unified_codes = torch.cat(unified_codes)
input_ids_list.append(unified_codes)
return self._pad(input_ids_list)
@staticmethod
def build_user_message(
text: Optional[str] = None,
reference: Optional[List[Optional[Union[str, torch.Tensor]]]] = None,
instruction: Optional[str] = None,
tokens: Optional[int] = None,
quality: Optional[str] = None,
sound_event: Optional[str] = None,
ambient_sound: Optional[str] = None,
language: Optional[str] = None,
) -> Dict:
if reference is not None and not isinstance(reference, list):
reference = [reference]
return UserMessage(
text=text,
reference=reference,
instruction=instruction,
tokens=tokens,
quality=quality,
sound_event=sound_event,
ambient_sound=ambient_sound,
language=language,
).to_dict()
@staticmethod
def build_assistant_message(
audio_codes_list: List[Union[str, torch.Tensor]],
content: str = AUDIO_PLACEHOLDER,
) -> Dict:
return AssistantMessage(
audio_codes_list=audio_codes_list,
content=content,
).to_dict()
def _normalize_message(self, message: Union[Message, Dict]) -> Dict:
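        """Accept a `Message` or a plain dict and return the normalized dict form
        with 'role', 'content', and 'audio_codes_list' keys.
        """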
if isinstance(message, Message):
return message.to_dict()
if not isinstance(message, dict):
raise TypeError("Each message must be a Message or dict.")
if "role" not in message:
raise ValueError("Message dict must include a 'role' field.")
if "content" in message and "audio_codes_list" in message:
return message
role = message["role"]
if role == "user":
kwargs = {key: message.get(key) for key in USER_MESSAGE_FIELDS}
return self.build_user_message(**kwargs)
if role == "assistant":
return self.build_assistant_message(
audio_codes_list=message.get("audio_codes_list", []),
content=message.get("content", AUDIO_PLACEHOLDER),
)
raise ValueError(f"Unsupported role: {role}")
def _pad(self, input_ids_list: List[torch.Tensor]):
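        """Left-pad the batch to a common length.

        Audio channels are padded with `audio_pad_code`; the text channel of padded
        positions is overwritten with `pad_token_id`. Returns `input_ids` together with
        a boolean `attention_mask` that is False on padded positions.
        """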
device = input_ids_list[0].device
lengths = torch.tensor([w.shape[0] for w in input_ids_list], device=device)
pad_input_ids = torch.nn.utils.rnn.pad_sequence(input_ids_list, batch_first=True, padding_value=self.model_config.audio_pad_code, padding_side="left")
other_channel_mask = (pad_input_ids.shape[1] - lengths).unsqueeze(1) > torch.arange(pad_input_ids.shape[1], device=device).unsqueeze(0)
pad_input_ids[..., 0][other_channel_mask] = self.model_config.pad_token_id
attention_mask = torch.zeros(pad_input_ids.shape[0], pad_input_ids.shape[1], device=device)
attention_mask[~other_channel_mask] = 1
attention_mask = attention_mask.bool()
return {
"input_ids": pad_input_ids, # [batch_size, seqlen, n_vq]
"attention_mask": attention_mask,
}
@staticmethod
def _replace_audio_placeholders(
content: str,
lengths: List[int],
n_vq: int,
gen_slot_token: str,
delay_slot_token: str,
audio_start_token: str,
audio_end_token: str
) -> str:
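        """Expand each <|audio|> placeholder into audio slot tokens.

        A placeholder covering `length` audio steps becomes
        `audio_start_token + gen_slot_token * length + delay_slot_token * (n_vq - 1) + audio_end_token`;
        a zero-length placeholder collapses to just the start/end pair.
        """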
if n_vq < 1:
raise ValueError(f"n_vq must be >= 1, got {n_vq}")
num_placeholders = content.count(AUDIO_PLACEHOLDER)
if num_placeholders != len(lengths):
raise ValueError(
f"Number of {AUDIO_PLACEHOLDER} ({num_placeholders}) "
f"does not match lengths ({len(lengths)})"
)
def build_audio_block(length: int) -> str:
if length < 0:
raise ValueError(f"length must be >= 0, got {length}")
if length == 0:
return f"{audio_start_token}{audio_end_token}"
step_tokens = gen_slot_token * length + (delay_slot_token * (n_vq - 1))
return f"{audio_start_token}{step_tokens}{audio_end_token}"
lengths_iter = iter(lengths)
def replacer(match: re.Match) -> str:
length = next(lengths_iter)
return build_audio_block(length)
result = re.sub(re.escape(AUDIO_PLACEHOLDER), replacer, content)
return result
@staticmethod
def _merge_consecutive_audio_placeholders(
content: str,
audio_codes_list: List[torch.Tensor],
) -> Tuple[str, List[torch.Tensor]]:
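        """Merge <|audio|> placeholders separated only by whitespace into a single placeholder,
        concatenating the corresponding code tensors along the time axis.
        """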
matches = list(re.finditer(re.escape(AUDIO_PLACEHOLDER), content))
if len(matches) <= 1:
return content, audio_codes_list
if len(matches) != len(audio_codes_list):
raise ValueError("Audio placeholders do not match the provided audio codes list.")
new_audio_codes_list = []
new_parts = []
last_pos = 0
i = 0
while i < len(matches):
j = i
while (
j + 1 < len(matches)
and content[matches[j].end():matches[j + 1].start()].strip() == ""
):
j += 1
new_parts.append(content[last_pos:matches[i].start()])
new_parts.append(AUDIO_PLACEHOLDER)
last_pos = matches[j].end()
if j == i:
new_audio_codes_list.append(audio_codes_list[i])
else:
new_audio_codes_list.append(torch.cat(audio_codes_list[i:j + 1], dim=0))
i = j + 1
new_parts.append(content[last_pos:])
return "".join(new_parts), new_audio_codes_list
@staticmethod
def apply_delay_pattern(codes: torch.Tensor, pad_code: int) -> torch.Tensor:
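        """Apply the delay interleaving pattern to audio codes of shape [T, n_vq]:
        codebook i is shifted down by i steps, gaps are filled with `pad_code`,
        and the result has shape [T + n_vq - 1, n_vq].
        """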
delayed_tokens = torch.full(
(codes.shape[0] + codes.shape[1] - 1, codes.shape[1]),
pad_code,
device=codes.device,
dtype=codes.dtype,
)
for i in range(codes.shape[1]):
delayed_tokens[i: i + codes.shape[0], i] = codes[:, i]
return delayed_tokens
@staticmethod
def apply_de_delay_pattern(delay_codes: torch.Tensor) -> torch.Tensor:
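        """Undo `apply_delay_pattern`, recovering [T, n_vq] codes from a [T + n_vq - 1, n_vq] tensor."""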
tokens = torch.full(
(delay_codes.shape[0] - delay_codes.shape[1] + 1, delay_codes.shape[1]),
0,
device=delay_codes.device,
dtype=delay_codes.dtype,
)
for i in range(delay_codes.shape[1]):
tokens[:, i] = delay_codes[i: i + tokens.shape[0], i]
return tokens
def _get_unified_codes(self, role: str, content: str, audio_codes_list: List[Union[str, torch.Tensor]], truncation: bool) -> torch.Tensor:
"""
此时的 content 已经是带上了对话格式
"""
if role == "user":
audio_gen_slot_token = audio_delay_slot_token = self.audio_user_slot_token
else:
audio_gen_slot_token = self.audio_assistant_gen_slot_token
audio_delay_slot_token = self.audio_assistant_delay_slot_token
if len(audio_codes_list):
n_vq = audio_codes_list[0].shape[1]
else:
n_vq = self.model_config.n_vq
if len(audio_codes_list) > 1 and AUDIO_PLACEHOLDER in content:
content, audio_codes_list = self._merge_consecutive_audio_placeholders(
content, audio_codes_list
)
content = self._replace_audio_placeholders(
content=content,
lengths=[len(audio_codes) for audio_codes in audio_codes_list],
n_vq=n_vq,
gen_slot_token=audio_gen_slot_token,
delay_slot_token=audio_delay_slot_token,
audio_start_token=self.audio_start_token,
audio_end_token=self.audio_end_token,
)
text_codes = torch.tensor(self.tokenizer.encode(content), device=audio_codes_list[0].device if audio_codes_list else None)
audio_start_indices = torch.where(text_codes == self.model_config.audio_start_token_id)[0]
audio_end_indices = torch.where(text_codes == self.model_config.audio_end_token_id)[0]
if len(audio_start_indices) != len(audio_codes_list) or len(audio_end_indices) != len(audio_codes_list):
raise ValueError("Audio placeholders do not match the provided audio codes list.")
delay_audio_codes_list = []
if len(audio_codes_list) == 0:
delay_audio_codes_list = torch.full(
(len(text_codes), n_vq),
self.model_config.audio_pad_code,
device=text_codes.device,
dtype=text_codes.dtype,
)
else:
prefix_idx = 0
for audio_start_idx, audio_end_idx, audio_codes in zip(audio_start_indices, audio_end_indices, audio_codes_list):
delay_audio_codes = self.apply_delay_pattern(audio_codes, self.model_config.audio_pad_code)
pad_codes = torch.full(
(audio_start_idx - prefix_idx + 1, n_vq),
self.model_config.audio_pad_code,
device=audio_codes.device,
dtype=audio_codes.dtype,
)
delay_audio_codes_list.extend([pad_codes, delay_audio_codes])
prefix_idx = audio_end_idx
if truncation:
delay_audio_codes_list[-1] = delay_audio_codes_list[-1][:-(n_vq - 1), :]
else:
pad_codes = torch.full(
(len(text_codes) - audio_end_indices[-1], n_vq),
self.model_config.audio_pad_code,
device=audio_codes_list[0].device,
dtype=audio_codes_list[0].dtype,
)
delay_audio_codes_list.append(pad_codes)
delay_audio_codes_list = torch.cat(delay_audio_codes_list)
if text_codes.shape[0] != delay_audio_codes_list.shape[0]:
text_codes = text_codes[:delay_audio_codes_list.shape[0]]
unified_codes = torch.cat([text_codes.unsqueeze(1), delay_audio_codes_list], dim=1)
return unified_codes
def _parse_text_codes(self, start_length, text_codes):
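        """Decode the text channel of a generation, drop the `start_length`-token prompt prefix,
        and collapse each generated audio token run back into a single <|audio|> placeholder
        (runs without any gen slot are removed).
        """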
text = self.tokenizer.decode(text_codes)
prefix = self.tokenizer.decode(text_codes[:start_length])
text = text[len(prefix):]
AUDIO_PATTERN = re.compile(
rf'(?:{self.audio_start_token})?'
rf'(?:{self.audio_assistant_gen_slot_token})*'
rf'(?:{self.audio_assistant_delay_slot_token})*'
rf'{self.audio_end_token}'
)
def normalize_audio_segments(text: str) -> str:
def repl(match: re.Match) -> str:
seg = match.group(0)
                # If the segment contains at least one gen slot, replace it with <|audio|>.
if self.audio_assistant_gen_slot_token in seg:
return AUDIO_PLACEHOLDER
                # Otherwise drop the segment entirely.
return ""
return AUDIO_PATTERN.sub(repl, text)
return normalize_audio_segments(text)
def _parse_audio_codes(self, start_length, audio_codes):
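        """Recover waveforms from the generated audio channels: de-delay the codes, split on
        all-pad rows into contiguous segments, decode each segment, and trim the portion of
        the first segment that belongs to the `start_length`-step prompt.
        """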
# De-delay back to [T', n_vq]
audio_codes = self.apply_de_delay_pattern(audio_codes)
# Rows that are all pad are separators between real audio segments.
is_pad = (audio_codes == self.model_config.audio_pad_code).all(dim=1)
non_pad = ~is_pad
if not non_pad.any():
return []
idx = torch.nonzero(non_pad).squeeze(1)
breaks = torch.where(idx[1:] != idx[:-1] + 1)[0] + 1
if breaks.numel() == 0:
segments_idx = [idx]
else:
segments_idx = torch.split(idx, breaks.tolist())
audio_codes_list = [audio_codes[s] for s in segments_idx]
decoded_audio_list = []
for segment_codes in audio_codes_list:
decoded_segment = self.decode_audio_codes([segment_codes])
if len(decoded_segment) > 0:
decoded_audio_list.append(decoded_segment[0])
# Keep codec causal context by decoding the whole first segment first,
# then trim at waveform level according to start_length ratio.
if start_length > 0 and len(audio_codes_list) > 0 and len(decoded_audio_list) > 0:
first_codes_length = audio_codes_list[0].shape[0]
if first_codes_length > 0:
trim_ratio = max(0.0, min(float(start_length) / float(first_codes_length), 1.0))
first_audio = decoded_audio_list[0]
if trim_ratio >= 1.0:
decoded_audio_list = decoded_audio_list[1:]
elif trim_ratio > 0.0:
trim_samples = int(first_audio.shape[-1] * trim_ratio)
decoded_audio_list[0] = first_audio[..., trim_samples:]
return decoded_audio_list
def decode(self, output: List[Tuple[int, torch.Tensor]]):
"""
1. 这里不管怎样,都需要一个完整的 assistant generation ids;
2. 支持从任意位置进行截断;
"""
        generated_messages = []
for start_length, generation_ids in output:
content = self._parse_text_codes(start_length, generation_ids[:, 0])
audio_codes_list = self._parse_audio_codes(start_length, generation_ids[:, 1:])
if content == "":
message = None
else:
message = AssistantMessage(
content=content,
audio_codes_list=audio_codes_list
)
            generated_messages.append(message)
        return generated_messages
@staticmethod
def loudness_normalize(wav: torch.Tensor, target_dbfs: float = -20, gain_range: tuple[float, float] = (-3.0, 3.0)) -> torch.Tensor:
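        """RMS-based loudness normalization towards `target_dbfs`, with the applied gain clamped to `gain_range` (dB)."""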
wav = wav.to(torch.float32)
if wav.numel() == 0: return wav
rms = torch.sqrt(torch.mean(wav ** 2))
current_dbfs = 20.0 * torch.log10(rms + 1e-9)
gain = float(target_dbfs - current_dbfs)
gain = max(gain_range[0], min(gain, gain_range[1]))
factor = 10.0 ** (gain / 20.0)
return wav * factor
def encode_audios_from_wav(self, wav_list: List[torch.Tensor], sampling_rate: int, n_vq: int = None):
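        """Downmix each waveform to mono, resample to the codec sampling rate if needed,
        loudness-normalize, and encode with the audio tokenizer into [T, n_vq] code tensors.
        """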
if isinstance(wav_list, torch.Tensor):
wav_list = [wav_list]
wav_list_ = []
resample = False
if sampling_rate != self.model_config.sampling_rate:
resample = True
for wav in wav_list:
if wav.shape[0] > 1:
wav = torch.mean(wav, dim=0, keepdim=True)
if resample:
wav = torchaudio.functional.resample(waveform=wav, orig_freq=sampling_rate, new_freq=self.model_config.sampling_rate)
wav_list_.append(self.loudness_normalize(wav.squeeze(0)))
return self.audio_tokenizer.encode(wav_list_, n_vq)
def encode_audios_from_path(self, wav_path_list: List[str], n_vq: int = None):
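        """Load audio files with torchaudio (all files in one call must share a sampling rate)
        and delegate to `encode_audios_from_wav`.
        """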
if isinstance(wav_path_list, str):
wav_path_list = [wav_path_list]
wav_list = []
sampling_rate = None
for wav_path in wav_path_list:
wav, sr = torchaudio.load(wav_path)
if sampling_rate is None:
sampling_rate = sr
elif sampling_rate != sr:
raise ValueError("sampling_rate of audios in the same batch should be the same.")
wav_list.append(wav)
return self.encode_audios_from_wav(wav_list, sampling_rate, n_vq)
def decode_audio_codes(self, audio_tokens_list: List[torch.Tensor]):
return self.audio_tokenizer.decode(audio_tokens_list)