YWMditto committed on
Commit
b12b6ca
·
1 Parent(s): f76c185

fix: remove unnecessary information

Browse files
Files changed (1) hide show
  1. processing_moss_tts.py +23 -38
processing_moss_tts.py CHANGED
@@ -88,7 +88,7 @@ class UserMessage(Message):
88
  reference = []
89
  for speaker_idx, speaker_reference in enumerate(self.reference):
90
  if speaker_reference is not None:
91
- reference.append(f"[S{speaker_idx}]:\n{AUDIO_PLACEHOLDER}")
92
  reference = "\n".join(reference)
93
  audio_codes_list = [
94
  speaker_reference
@@ -333,7 +333,7 @@ class MossTTSDelayProcessor(ProcessorMixin):
333
  )
334
 
335
  if len(paths) > 0:
336
- encoded_from_paths = self.encode_audios_from_path(paths, n_vq) # List
337
  if len(encoded_from_paths) != len(paths):
338
  raise RuntimeError(
339
  "encode_audios_from_path returned an unexpected number of items."
@@ -462,7 +462,6 @@ class MossTTSDelayProcessor(ProcessorMixin):
462
  if length == 0:
463
  return f"{audio_start_token}{audio_end_token}"
464
 
465
- # step_tokens = gen_slot_token * length + (delay_slot_token * (n_vq - 1))
466
  step_tokens = gen_slot_token * length
467
  return f"{audio_start_token}{step_tokens}{audio_end_token}"
468
 
@@ -554,6 +553,7 @@ class MossTTSDelayProcessor(ProcessorMixin):
554
  """
555
  if role == "user":
556
  audio_gen_slot_token = audio_delay_slot_token = self.audio_user_slot_token
 
557
  else:
558
  audio_gen_slot_token = self.audio_assistant_gen_slot_token
559
  audio_delay_slot_token = self.audio_assistant_delay_slot_token
@@ -610,9 +610,6 @@ class MossTTSDelayProcessor(ProcessorMixin):
610
  ):
611
  audio_start_idx = int(audio_start_idx_t.item())
612
  audio_end_idx = int(audio_end_idx_t.item())
613
- # delay_audio_codes = self.apply_delay_pattern(
614
- # audio_codes, self.model_config.audio_pad_code
615
- # )
616
  delay_audio_codes = audio_codes # not delay
617
  pad_codes = torch.full(
618
  (audio_start_idx - prefix_idx + 1, n_vq),
@@ -624,10 +621,7 @@ class MossTTSDelayProcessor(ProcessorMixin):
624
  prefix_idx = audio_end_idx
625
 
626
  if truncation:
627
- # delay_audio_codes_list[-1] = delay_audio_codes_list[-1][
628
- # : -(n_vq - 1), :
629
- # ]
630
- ...
631
  else:
632
  last_audio_end_idx = int(audio_end_indices[-1].item())
633
  pad_codes = torch.full(
@@ -675,8 +669,6 @@ class MossTTSDelayProcessor(ProcessorMixin):
675
 
676
  def _parse_audio_codes(self, start_length, audio_codes):
677
  # De-delay back to [T', n_vq]
678
- # audio_codes = self.apply_de_delay_pattern(audio_codes)
679
-
680
  # Rows that are all pad are separators between real audio segments.
681
  is_pad = (audio_codes == self.model_config.audio_pad_code).all(dim=1)
682
  non_pad = ~is_pad
@@ -688,8 +680,6 @@ class MossTTSDelayProcessor(ProcessorMixin):
688
  if breaks.numel() == 0:
689
  segments_idx = [idx]
690
  else:
691
- # assert len(breaks) == 1
692
- # segments_idx = torch.split(idx, [breaks.tolist()[0], len(idx) - breaks.tolist()[0]])
693
  segments_idx = torch.split(idx, breaks.tolist())
694
 
695
  audio_codes_list = [audio_codes[s] for s in segments_idx]
@@ -906,30 +896,25 @@ class MossTTSDelayProcessor(ProcessorMixin):
906
  codes.transpose(0, 1).contiguous().to(device=device, dtype=torch.long)
907
  for codes in audio_tokens_list
908
  ]
909
-
910
- if hasattr(audio_tokenizer, "batch_decode"):
911
- dec = audio_tokenizer.batch_decode(codes_list)
912
- audio = dec.audio # (B, C, T)
913
- audio_lengths = dec.audio_lengths # (B,)
914
- else:
915
- # Fallback: pad to (NQ, B, T) + mask, then decode.
916
- nq = int(codes_list[0].shape[0])
917
- max_t = max(int(c.shape[1]) for c in codes_list)
918
- audio_codes = torch.zeros(
919
- nq, len(codes_list), max_t, device=device, dtype=torch.long
920
- )
921
- padding_mask = torch.zeros(
922
- len(codes_list), max_t, device=device, dtype=torch.bool
923
- )
924
- for i, c in enumerate(codes_list):
925
- t = int(c.shape[1])
926
- audio_codes[:, i, :t] = c
927
- padding_mask[i, :t] = True
928
- dec = audio_tokenizer.decode(
929
- audio_codes, padding_mask=padding_mask, return_dict=True
930
- )
931
- audio = dec.audio
932
- audio_lengths = dec.audio_lengths
933
 
934
  if audio is None or audio_lengths is None:
935
  raise RuntimeError(
 
88
  reference = []
89
  for speaker_idx, speaker_reference in enumerate(self.reference):
90
  if speaker_reference is not None:
91
+ reference.append(f"[S{speaker_idx+1}]:\n{AUDIO_PLACEHOLDER}")
92
  reference = "\n".join(reference)
93
  audio_codes_list = [
94
  speaker_reference
 
333
  )
334
 
335
  if len(paths) > 0:
336
+ encoded_from_paths = self.encode_audios_from_path(paths, n_vq)
337
  if len(encoded_from_paths) != len(paths):
338
  raise RuntimeError(
339
  "encode_audios_from_path returned an unexpected number of items."
 
462
  if length == 0:
463
  return f"{audio_start_token}{audio_end_token}"
464
 
 
465
  step_tokens = gen_slot_token * length
466
  return f"{audio_start_token}{step_tokens}{audio_end_token}"
467
 
 
553
  """
554
  if role == "user":
555
  audio_gen_slot_token = audio_delay_slot_token = self.audio_user_slot_token
556
+ truncation = False
557
  else:
558
  audio_gen_slot_token = self.audio_assistant_gen_slot_token
559
  audio_delay_slot_token = self.audio_assistant_delay_slot_token
 
610
  ):
611
  audio_start_idx = int(audio_start_idx_t.item())
612
  audio_end_idx = int(audio_end_idx_t.item())
 
 
 
613
  delay_audio_codes = audio_codes # not delay
614
  pad_codes = torch.full(
615
  (audio_start_idx - prefix_idx + 1, n_vq),
 
621
  prefix_idx = audio_end_idx
622
 
623
  if truncation:
624
+ raise RuntimeError("Truncation generation is not supported at present")
 
 
 
625
  else:
626
  last_audio_end_idx = int(audio_end_indices[-1].item())
627
  pad_codes = torch.full(
 
669
 
670
  def _parse_audio_codes(self, start_length, audio_codes):
671
  # De-delay back to [T', n_vq]
 
 
672
  # Rows that are all pad are separators between real audio segments.
673
  is_pad = (audio_codes == self.model_config.audio_pad_code).all(dim=1)
674
  non_pad = ~is_pad
 
680
  if breaks.numel() == 0:
681
  segments_idx = [idx]
682
  else:
 
 
683
  segments_idx = torch.split(idx, breaks.tolist())
684
 
685
  audio_codes_list = [audio_codes[s] for s in segments_idx]
 
896
  codes.transpose(0, 1).contiguous().to(device=device, dtype=torch.long)
897
  for codes in audio_tokens_list
898
  ]
899
+
900
+ # Fallback: pad to (NQ, B, T) + mask, then decode.
901
+ nq = int(codes_list[0].shape[0])
902
+ max_t = max(int(c.shape[1]) for c in codes_list)
903
+ audio_codes = torch.zeros(
904
+ nq, len(codes_list), max_t, device=device, dtype=torch.long
905
+ )
906
+ padding_mask = torch.zeros(
907
+ len(codes_list), max_t, device=device, dtype=torch.bool
908
+ )
909
+ for i, c in enumerate(codes_list):
910
+ t = int(c.shape[1])
911
+ audio_codes[:, i, :t] = c
912
+ padding_mask[i, :t] = True
913
+ dec = audio_tokenizer.decode(
914
+ audio_codes, padding_mask=padding_mask, return_dict=True, chunk_duration=8
915
+ )
916
+ audio = dec.audio
917
+ audio_lengths = dec.audio_lengths
 
 
 
 
 
918
 
919
  if audio is None or audio_lengths is None:
920
  raise RuntimeError(