fix: remove unnecessary information

processing_moss_tts.py  (+23 -38)
@@ -88,7 +88,7 @@ class UserMessage(Message):
         reference = []
         for speaker_idx, speaker_reference in enumerate(self.reference):
             if speaker_reference is not None:
-                reference.append(f"[S{speaker_idx}]:\n{AUDIO_PLACEHOLDER}")
+                reference.append(f"[S{speaker_idx+1}]:\n{AUDIO_PLACEHOLDER}")
         reference = "\n".join(reference)
         audio_codes_list = [
             speaker_reference
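The only change in this hunk is the switch to 1-based speaker labels. A minimal sketch of the resulting behaviour, with a hypothetical AUDIO_PLACEHOLDER value (the real constant lives in the processor module):

```python
AUDIO_PLACEHOLDER = "<|audio|>"  # hypothetical value, for illustration only

references = ["codes_a", None, "codes_b"]  # toy per-speaker references

lines = []
for speaker_idx, speaker_reference in enumerate(references):
    if speaker_reference is not None:
        # enumerate() is 0-based, so +1 yields the labels [S1], [S2], ...
        lines.append(f"[S{speaker_idx+1}]:\n{AUDIO_PLACEHOLDER}")

print("\n".join(lines))  # prints [S1] and [S3]; speaker 2 has no reference
```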
@@ -333,7 +333,7 @@ class MossTTSDelayProcessor(ProcessorMixin):
         )

         if len(paths) > 0:
-            encoded_from_paths = self.encode_audios_from_path(paths, n_vq)
+            encoded_from_paths = self.encode_audios_from_path(paths, n_vq)
             if len(encoded_from_paths) != len(paths):
                 raise RuntimeError(
                     "encode_audios_from_path returned an unexpected number of items."
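The surrounding guard is a fail-fast sanity check: a batch encoder must return exactly one item per input path. A minimal sketch with a stand-in encoder (fake_encode is not the real implementation):

```python
def fake_encode(paths, n_vq):
    # Stand-in for self.encode_audios_from_path: one result per input path.
    return [f"codes({p}, n_vq={n_vq})" for p in paths]

paths = ["ref_a.wav", "ref_b.wav"]
encoded_from_paths = fake_encode(paths, n_vq=8)
if len(encoded_from_paths) != len(paths):
    raise RuntimeError(
        "encode_audios_from_path returned an unexpected number of items."
    )
```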
@@ -462,7 +462,6 @@ class MossTTSDelayProcessor(ProcessorMixin):
         if length == 0:
             return f"{audio_start_token}{audio_end_token}"

-        # step_tokens = gen_slot_token * length + (delay_slot_token * (n_vq - 1))
         step_tokens = gen_slot_token * length
         return f"{audio_start_token}{step_tokens}{audio_end_token}"

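The deleted comment was a leftover from the delay-pattern variant; the live code reserves exactly one generation slot per audio step. A sketch with illustrative token strings:

```python
# Token strings are illustrative; the real ones come from the tokenizer vocab.
audio_start_token, audio_end_token = "<audio_start>", "<audio_end>"
gen_slot_token = "<slot>"

def audio_slots(length: int) -> str:
    if length == 0:
        return f"{audio_start_token}{audio_end_token}"
    step_tokens = gen_slot_token * length  # one slot per audio timestep
    return f"{audio_start_token}{step_tokens}{audio_end_token}"

print(audio_slots(0))  # <audio_start><audio_end>
print(audio_slots(3))  # <audio_start><slot><slot><slot><audio_end>
```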
@@ -554,6 +553,7 @@ class MossTTSDelayProcessor(ProcessorMixin):
         """
         if role == "user":
             audio_gen_slot_token = audio_delay_slot_token = self.audio_user_slot_token
+            truncation = False
         else:
             audio_gen_slot_token = self.audio_assistant_gen_slot_token
             audio_delay_slot_token = self.audio_assistant_delay_slot_token
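Forcing truncation = False for the user role pairs with the new RuntimeError further down: user reference audio is always materialised in full, so only assistant audio could reach the truncation path. A rough sketch of the branch, with made-up token strings:

```python
def pick_slot_tokens(role: str, truncation: bool):
    # Token strings are made up for illustration.
    if role == "user":
        gen_slot = delay_slot = "<user_slot>"
        truncation = False  # user reference audio is never truncated
    else:
        gen_slot, delay_slot = "<asst_gen_slot>", "<asst_delay_slot>"
    return gen_slot, delay_slot, truncation

print(pick_slot_tokens("user", truncation=True))       # (..., ..., False)
print(pick_slot_tokens("assistant", truncation=True))  # (..., ..., True)
```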
@@ -610,9 +610,6 @@ class MossTTSDelayProcessor(ProcessorMixin):
         ):
             audio_start_idx = int(audio_start_idx_t.item())
             audio_end_idx = int(audio_end_idx_t.item())
-            # delay_audio_codes = self.apply_delay_pattern(
-            #     audio_codes, self.model_config.audio_pad_code
-            # )
             delay_audio_codes = audio_codes  # not delay
             pad_codes = torch.full(
                 (audio_start_idx - prefix_idx + 1, n_vq),
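For context on the dropped call: a delay pattern staggers the n_vq codebooks so that codebook q is shifted right by q steps. The helper below is an illustrative sketch of that idea, not the removed implementation:

```python
import torch

def apply_delay_pattern(codes: torch.Tensor, pad_code: int) -> torch.Tensor:
    """Illustrative only: shift codebook q right by q steps, padding with pad_code."""
    T, n_vq = codes.shape
    out = torch.full((T + n_vq - 1, n_vq), pad_code, dtype=codes.dtype)
    for q in range(n_vq):
        out[q : q + T, q] = codes[:, q]
    return out

codes = torch.arange(6).reshape(3, 2)          # T=3, n_vq=2
print(apply_delay_pattern(codes, pad_code=-1))
# tensor([[ 0, -1],
#         [ 2,  1],
#         [ 4,  3],
#         [-1,  5]])
```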
@@ -624,10 +621,7 @@ class MossTTSDelayProcessor(ProcessorMixin):
             prefix_idx = audio_end_idx

         if truncation:
-
-            # : -(n_vq - 1), :
-            # ]
-            ...
+            raise RuntimeError("Truncation generation is not supported at present")
         else:
             last_audio_end_idx = int(audio_end_indices[-1].item())
             pad_codes = torch.full(
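The old branch body was dead code (commented-out slicing plus a bare `...`), so truncated generation silently produced nothing; raising makes the unsupported path explicit. Minimal sketch:

```python
def build_audio_codes(truncation: bool):
    if truncation:
        # Fail fast instead of silently falling through a no-op branch.
        raise RuntimeError("Truncation generation is not supported at present")
    return "full-length codes"  # stands in for the real padding logic

try:
    build_audio_codes(truncation=True)
except RuntimeError as err:
    print(err)  # Truncation generation is not supported at present
```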
@@ -675,8 +669,6 @@ class MossTTSDelayProcessor(ProcessorMixin):

     def _parse_audio_codes(self, start_length, audio_codes):
         # De-delay back to [T', n_vq]
-        # audio_codes = self.apply_de_delay_pattern(audio_codes)
-
         # Rows that are all pad are separators between real audio segments.
         is_pad = (audio_codes == self.model_config.audio_pad_code).all(dim=1)
         non_pad = ~is_pad
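With the de-delay call gone, audio_codes is already [T', n_vq], and segmentation relies purely on all-pad separator rows. A small sketch of the detection, assuming a pad code of 0:

```python
import torch

audio_pad_code = 0  # assumed value, for illustration
audio_codes = torch.tensor([
    [1, 2],
    [3, 4],
    [0, 0],  # every codebook equals the pad code -> segment separator
    [5, 6],
])
is_pad = (audio_codes == audio_pad_code).all(dim=1)
non_pad = ~is_pad
print(is_pad)  # tensor([False, False,  True, False])
```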
@@ -688,8 +680,6 @@ class MossTTSDelayProcessor(ProcessorMixin):
         if breaks.numel() == 0:
             segments_idx = [idx]
         else:
-            # assert len(breaks) == 1
-            # segments_idx = torch.split(idx, [breaks.tolist()[0], len(idx) - breaks.tolist()[0]])
             segments_idx = torch.split(idx, breaks.tolist())

         audio_codes_list = [audio_codes[s] for s in segments_idx]
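The deleted comments assumed exactly one break (two segments); the live torch.split generalises to any number. A sketch, assuming breaks holds per-segment lengths summing to len(idx), which torch.split requires:

```python
import torch

idx = torch.tensor([0, 1, 3, 4, 6])  # non-pad row indices (illustrative)
breaks = torch.tensor([2, 2, 1])     # per-segment lengths (illustrative)
segments_idx = torch.split(idx, breaks.tolist())
print([s.tolist() for s in segments_idx])  # [[0, 1], [3, 4], [6]]
```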
@@ -906,30 +896,25 @@ class MossTTSDelayProcessor(ProcessorMixin):
             codes.transpose(0, 1).contiguous().to(device=device, dtype=torch.long)
             for codes in audio_tokens_list
         ]
-
-        max_t =
-        )
-        dec = audio_tokenizer.decode(
-            audio_codes, padding_mask=padding_mask, return_dict=True
-        )
-        audio = dec.audio
-        audio_lengths = dec.audio_lengths
+
+        # Fallback: pad to (NQ, B, T) + mask, then decode.
+        nq = int(codes_list[0].shape[0])
+        max_t = max(int(c.shape[1]) for c in codes_list)
+        audio_codes = torch.zeros(
+            nq, len(codes_list), max_t, device=device, dtype=torch.long
+        )
+        padding_mask = torch.zeros(
+            len(codes_list), max_t, device=device, dtype=torch.bool
+        )
+        for i, c in enumerate(codes_list):
+            t = int(c.shape[1])
+            audio_codes[:, i, :t] = c
+            padding_mask[i, :t] = True
+        dec = audio_tokenizer.decode(
+            audio_codes, padding_mask=padding_mask, return_dict=True, chunk_duration=8
+        )
+        audio = dec.audio
+        audio_lengths = dec.audio_lengths

         if audio is None or audio_lengths is None:
             raise RuntimeError(
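The new fallback packs variable-length [NQ, T_i] code tensors into a single [NQ, B, T_max] batch plus a [B, T_max] boolean mask of real frames before decoding. A self-contained sketch of the padding step (the audio_tokenizer.decode call itself is omitted):

```python
import torch

codes_list = [
    torch.ones(4, 3, dtype=torch.long),  # NQ=4, T=3
    torch.ones(4, 5, dtype=torch.long),  # NQ=4, T=5
]

nq = int(codes_list[0].shape[0])
max_t = max(int(c.shape[1]) for c in codes_list)
audio_codes = torch.zeros(nq, len(codes_list), max_t, dtype=torch.long)
padding_mask = torch.zeros(len(codes_list), max_t, dtype=torch.bool)
for i, c in enumerate(codes_list):
    t = int(c.shape[1])
    audio_codes[:, i, :t] = c   # right-pad with zeros
    padding_mask[i, :t] = True  # mark real frames

print(audio_codes.shape)        # torch.Size([4, 2, 5])
print(padding_mask.sum(dim=1))  # tensor([3, 5]) real frames per item
```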