import copy
from pathlib import Path

import torch
from torch import nn
from torchaudio.models import Conformer
from f5_tts.model.utils import (
    default,
    exists,
    list_str_to_idx,
    list_str_to_tensor,
    lens_to_mask,
    mask_from_frac_lengths,
)

class ResBlock(nn.Module):
    """Stack of dilated 1D-conv residual blocks operating on (batch, channels, time)."""

    def __init__(self, hidden_dim, n_conv=3, dropout_p=0.2):
        super().__init__()
        self._n_groups = 8
        self.blocks = nn.ModuleList(
            [
                self._get_conv(hidden_dim, dilation=3**i, dropout_p=dropout_p)
                for i in range(n_conv)
            ]
        )

    def forward(self, x):
        for block in self.blocks:
            res = x
            x = block(x)
            x = x + res  # residual connection around each dilated conv block
        return x

    def _get_conv(self, hidden_dim, dilation, dropout_p=0.2):
        layers = [
            nn.Conv1d(hidden_dim, hidden_dim, kernel_size=3, padding=dilation, dilation=dilation),
            nn.ReLU(),
            nn.GroupNorm(num_groups=self._n_groups, num_channels=hidden_dim),
            nn.Dropout(p=dropout_p),
            nn.Conv1d(hidden_dim, hidden_dim, kernel_size=3, padding=1, dilation=1),
            nn.ReLU(),
            nn.Dropout(p=dropout_p),
        ]
        return nn.Sequential(*layers)
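
# Shape sanity check for ResBlock (illustrative sketch, not part of the model):
#
#   block = ResBlock(hidden_dim=512)
#   y = block(torch.randn(2, 512, 100))   # input is (batch, channels, time)
#   assert y.shape == (2, 512, 100)       # dilated convs + residuals preserve the shape
#
# Note: GroupNorm inside the block uses 8 groups, so hidden_dim must be divisible by 8.
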
class ConformerCTC(nn.Module):
    """Conformer-based CTC recognizer over mel (or latent) features.

    The conv front-end downsamples time by 4x; the output layer predicts
    vocab_size + 1 classes, with the CTC blank at index vocab_size.
    """

    def __init__(self, vocab_size, mel_dim=100, num_heads=8, d_hid=512, nlayers=6):
        super().__init__()
        self.mel_proj = nn.Conv1d(mel_dim, d_hid, kernel_size=3, padding=1)
        self.d_hid = d_hid
        self.resblock1 = nn.Sequential(
            ResBlock(d_hid),
            nn.GroupNorm(num_groups=1, num_channels=d_hid),
        )
        self.resblock2 = nn.Sequential(
            ResBlock(d_hid),
            nn.GroupNorm(num_groups=1, num_channels=d_hid),
        )
        self.conf_pre = nn.ModuleList(
            [
                Conformer(
                    input_dim=d_hid,
                    num_heads=num_heads,
                    ffn_dim=d_hid * 2,
                    num_layers=1,
                    depthwise_conv_kernel_size=15,
                    use_group_norm=True,
                )
                for _ in range(nlayers // 2)
            ]
        )
        self.conf_after = nn.ModuleList(
            [
                Conformer(
                    input_dim=d_hid,
                    num_heads=num_heads,
                    ffn_dim=d_hid * 2,
                    num_layers=1,
                    depthwise_conv_kernel_size=7,
                    use_group_norm=True,
                )
                for _ in range(nlayers // 2)
            ]
        )
        self.out = nn.Linear(d_hid, 1 + vocab_size)  # +1 for the CTC blank
        self.ctc_loss = nn.CTCLoss(blank=vocab_size, zero_infinity=True)  # parameter-free, device follows its inputs

    def forward(self, latent, text=None, text_lens=None):
        layers = []  # intermediate features, each with time downsampled by 4
        # (batch, time, mel_dim) -> (batch, d_hid, time) for the conv front-end
        x = self.mel_proj(latent.transpose(-1, -2))
        layers.append(nn.functional.avg_pool1d(x, 4))
        x = self.resblock1(x)
        x = nn.functional.avg_pool1d(x, 2)
        layers.append(nn.functional.avg_pool1d(x, 2))
        x = self.resblock2(x)
        x = nn.functional.avg_pool1d(x, 2)
        layers.append(x)
        # (batch, d_hid, time/4) -> (batch, time/4, d_hid) for the Conformer blocks
        x = x.transpose(1, 2)
        batch_size, time_steps, _ = x.shape
        # Dummy lengths: every sequence is treated as full length (no padding mask).
        input_lengths = torch.full((batch_size,), time_steps, device=x.device, dtype=torch.int64)
        for layer in self.conf_pre:
            x, _ = layer(x, input_lengths)
            layers.append(x.transpose(1, 2))
        for layer in self.conf_after:
            x, _ = layer(x, input_lengths)
            layers.append(x.transpose(1, 2))
        x = self.out(x)  # (batch, time/4, vocab_size + 1) logits
        if text_lens is not None and text is not None:
            # nn.CTCLoss expects (time, batch, classes) log-probabilities
            loss = self.ctc_loss(x.log_softmax(dim=2).transpose(0, 1), text, input_lengths, text_lens)
            return x, layers, loss
        else:
            return x, layers
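
# Minimal usage sketch (illustrative only; the vocab_size and shapes below are
# assumptions, not values fixed by this file):
#
#   model = ConformerCTC(vocab_size=2545, mel_dim=100)
#   mel = torch.randn(2, 400, 100)      # (batch, time, mel_dim)
#   logits, feats = model(mel)          # no text/text_lens -> no CTC loss returned
#   # logits: (2, 100, vocab_size + 1) since time is downsampled by 4
#   # feats:  list of intermediate features, each shaped (2, 512, 100)
#
# Passing `text` (token indices) and `text_lens` additionally returns the CTC loss.
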
if __name__ == "__main__":
    from f5_tts.model.utils import get_tokenizer

    bsz = 16
    tokenizer = "pinyin"  # 'pinyin', 'char', or 'custom'
    tokenizer_path = None  # if tokenizer = 'custom', path to the vocab file you want to use (vocab.txt)
    dataset_name = "Emilia_ZH_EN"
    if tokenizer == "custom":
        tokenizer_path = tokenizer_path
    else:
        tokenizer_path = dataset_name
    vocab_char_map, vocab_size = get_tokenizer(tokenizer_path, tokenizer)

    model = ConformerCTC(vocab_size, mel_dim=80, num_heads=8, d_hid=512, nlayers=6).cuda()
    text = ["hello world"] * bsz
    lens = torch.randint(1, 1000, (bsz,)).cuda()
    inp = torch.randn(bsz, lens.max(), 80).cuda()

    batch, seq_len, dtype, device = *inp.shape[:2], inp.dtype, inp.device

    # handle text as string
    text_lens = torch.tensor([len(t) for t in text], device=device)
    if isinstance(text, list):
        if exists(vocab_char_map):
            text = list_str_to_idx(text, vocab_char_map).to(device)
        else:
            text = list_str_to_tensor(text).to(device)
        assert text.shape[0] == batch

    # lens and mask
    if not exists(lens):
        lens = torch.full((batch,), seq_len, device=device)

    out, layers, loss = model(inp, text, text_lens)  # loss is only returned when text is given
    print(out.shape)
    print(out)
    print(len(layers))
    print(torch.stack(layers, dim=1).shape)  # (batch, n_features, d_hid, time/4)
    print(loss)
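
    # Greedy CTC decoding: pick the argmax class per frame, collapse consecutive
    # repeats, then drop blanks (blank index == vocab_size, as set in nn.CTCLoss).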
    probs = out.softmax(dim=2)  # convert logits to per-frame probabilities
    best_path = torch.argmax(probs, dim=2)

    decoded_sequences = []
    blank_idx = vocab_size
    idx_to_char = list(vocab_char_map.keys())  # assumes dict insertion order matches token indices
    for path in best_path:
        decoded_sequence = []
        previous_token = None
        for token in path:
            if token != previous_token:  # collapse repeated tokens
                if token != blank_idx:  # ignore blank tokens
                    decoded_sequence.append(token.item())
            previous_token = token
        decoded_sequences.append(decoded_sequence)

    # Convert token indices back to characters
    decoded_texts = [''.join(idx_to_char[token] for token in sequence) for sequence in decoded_sequences]
    gt_texts = []
    for i in range(text_lens.size(0)):
        gt_texts.append(''.join(idx_to_char[token] for token in text[i, :text_lens[i]]))
    print(decoded_texts)
    print(gt_texts)