|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import logging |
|
|
import math |
|
|
from typing import Optional, Tuple |
|
|
import random |
|
|
|
|
|
import numpy as np |
|
|
import torch |
|
|
import torchaudio |
|
|
import torch.nn as nn |
|
|
from torch.nn.utils.rnn import pad_sequence |
|
|
from torchaudio.compliance.kaldi import fbank as torch_fbank |
|
|
|
|
|
from .configuration_spear import SpearConfig |
|
|
from .zipformer import Zipformer2, Conv2dSubsampling |
|
|
|
|
|
# Padding value for log-mel feature frames: log(1e-10), i.e. (near) silence.
LOG_EPS=math.log(1e-10)

# All audio is (re)sampled to this rate before feature extraction.
SAMPLING_RATE=16000
|
|
|
|
|
def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
    """Build a boolean padding mask from sequence lengths.

    Args:
        lengths: a 1-D tensor of sentence lengths, shape (N,).
        max_len: minimum width of the mask; the actual width is
            ``max(max_len, lengths.max())``.

    Returns:
        A 2-D bool tensor where masked (padded) positions are ``True``
        and non-masked positions are ``False``.

    This function is borrowed from https://github.com/k2-fsa/icefall

    >>> lengths = torch.tensor([1, 3, 2, 5])
    >>> make_pad_mask(lengths)
    tensor([[False, True, True, True, True],
            [False, False, False, True, True],
            [False, False, True, True, True],
            [False, False, False, False, False]])
    """
    assert lengths.ndim == 1, lengths.ndim
    width = max(max_len, lengths.max())
    batch = lengths.size(0)

    # Each row holds the position indices 0..width-1; a position is padding
    # exactly when its index is >= the sequence length of that row.
    positions = torch.arange(0, width, device=lengths.device)
    grid = positions.unsqueeze(0).expand(batch, width)
    return grid >= lengths.unsqueeze(-1)
|
|
|
|
|
def get_model(config: SpearConfig) -> nn.Module:
    """Construct a SpearEncoder from the given configuration.

    The returned model has the Conv2d subsampling frontend and the
    Zipformer2 encoder wired together. ``num_codebooks=0`` disables the
    codebook-distillation loss head (inference-only configuration).
    """
    return SpearEncoder(
        encoder_embed=get_encoder_embed(config),
        encoder=get_encoder_model(config),
        encoder_dim=max(_to_int_tuple(config.encoder_dim)),
        num_codebooks=0,
    )
|
|
|
|
|
class SpearModel(nn.Module):
    """End-to-end wrapper: loads audio files, extracts 128-bin log-mel fbank
    features, and encodes them with a SpearEncoder built from ``config``."""

    def __init__(
        self, config: SpearConfig,
    ):
        super().__init__()
        model = get_model(config)
        self.config = config
        self.model = model

    def _load_audio_single(self, audio_path: str) -> Tuple[torch.Tensor, int]:
        """Read one file, down-mix to mono and resample to 16 kHz.

        Returns the (1, L) waveform and its length L in samples.
        """
        waveform, sr = torchaudio.load(audio_path)
        if waveform.size(0) > 1:
            # Average channels to obtain a mono signal.
            waveform = waveform.mean(dim=0, keepdim=True)
        if sr != SAMPLING_RATE:
            resampler = torchaudio.transforms.Resample(sr, SAMPLING_RATE)
            waveform = resampler(waveform)
        return waveform, waveform.shape[-1]

    def load_audio(self, audio_paths: list[str]) -> Tuple[torch.Tensor, torch.Tensor]:
        """Load a batch of audio files and zero-pad them to a common length.

        Returns the padded (N, L) waveforms and a (N,) tensor of sample counts.
        """
        assert isinstance(audio_paths, list), "Must receive a list of files for reading"
        wav_list = []
        len_list = []
        for path in audio_paths:
            wav, num_samples = self._load_audio_single(path)
            wav_list.append(wav.squeeze())
            len_list.append(num_samples)

        padded = pad_sequence(wav_list, batch_first=True)
        return padded, torch.tensor(len_list)

    def compute_fbank(
        self, wavs: torch.Tensor, wav_lens: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute fbank features

        Args:
            wavs (torch.Tensor): the mono-channel input waveform, (N, T)
            wav_lens (torch.Tensor): the length of each waveform in samples (N)

        Returns:
            The fbank features, and their lengths
        """
        assert wavs.ndim == 2, wavs.shape

        # Kaldi-compatible fbank settings; a negative high_freq is an offset
        # from the Nyquist frequency (Kaldi convention).
        fbank_kwargs = dict(
            sample_frequency=16000,
            num_mel_bins=128,
            low_freq=20.0,
            high_freq=-400.0,
            dither=0.0,
            snip_edges=False,
            energy_floor=1.0e-10,
        )

        # Each utterance is trimmed to its true length before feature
        # extraction so padded samples never influence the features.
        features = [
            torch_fbank(wav[:wav_lens[i]].unsqueeze(0), **fbank_kwargs)
            for i, wav in enumerate(wavs)
        ]
        feat_len = torch.tensor([f.shape[0] for f in features]).to(wavs.device)
        features = pad_sequence(features, batch_first=True, padding_value=LOG_EPS).to(wavs.device)
        return features, feat_len

    def forward(self, audio: torch.Tensor, audio_lens: torch.Tensor, return_middle_layers: bool = True):
        """Encode a batch of audio

        Args:
            audio (torch.Tensor): Input audio waveforms (N,L)
            audio_lens (torch.Tensor): The length of the audio waveforms (N)
            return_middle_layers (bool, optional): Output the intermediate features.

        Returns:
            The encoded representations, and the length of each representation (N,T,C), (N)
        """
        x, x_lens = self.compute_fbank(audio, audio_lens)
        return self.model.forward_encoder(
            x=x,
            x_lens=x_lens,
            return_middle_out=return_middle_layers,
            return_dict=True,
        )
|
|
|
|
|
|
|
|
class SpearEncoder(nn.Module):
    """Speech encoder pre-trained with multi-codebook vector-quantization (MVQ)
    knowledge distillation, using wav2vec2-style masking of the fbank input."""

    def __init__(
        self,
        encoder_embed: nn.Module,
        encoder: nn.Module,
        encoder_dim: int,
        num_codebooks: int = 8,
        distillation_layer: int = 9,
        distillation_delta: int = 0,
        teacher_frame_ratio: int = 2,
        interpolate_teacher: bool = False,
        n_mels: int = 128,
        mask_mode: str = "w2v2",
        mask_prob: float = 0.65,
        mask_length: int = 10,
        mask_selection: str = "static",
        mask_other: float = 0.0,
        min_masks: int = 2,
        mask_channel_prob: float = 0.0,
        mask_channel_length: int = 10,
        mask_channel_selection: str = "static",
        mask_channel_other: float = 0.0,
        loss_only_mask: bool = False,
    ):
        """A model that performs MVQ KD pre-training.

        Args:
          encoder_embed:
            A Convolutional 2D subsampling module. It converts an input
            of shape (N, T, idim) to an output of shape (N, T', odim),
            where T' = (T-3)//2-2 = (T-7)//2.
          encoder:
            The transcription network. It accepts two inputs: `x` of shape
            (N, T, encoder_dim) and `x_lens` of shape (N,). It returns two
            tensors: `logits` of shape (N, T, encoder_dim) and `logit_lens`
            of shape (N,).
          encoder_dim:
            Dimension of the encoder output consumed by the codebook loss.
          num_codebooks:
            The number of codebooks used in the target. 0 disables the
            codebook loss head entirely.
          distillation_layer:
            Use which layer to do MVQ pre-training.
          distillation_delta:
            How many frames to delay the alignment between the model and the
            target frames. Should be zero for non-streaming models, and a
            positive number for streaming models.
          teacher_frame_ratio:
            The frame rate ratio between the target and the model output.
          interpolate_teacher:
            If True, time-interpolate the teacher indexes to the student's
            frame rate instead of concatenating successive teacher frames.
          n_mels:
            Number of input mel bins; size of the learned mask embedding.
          mask_mode:
            The masking mode.
            w2v2: the wav2vec2 style of masking, allows overlap
            block: no overlap, therefore bigger masking ratio
          mask_prob:
            The probability of choosing one frame as the start index of a mask.
          mask_length:
            The length of each mask.
          mask_selection:
            How to determine the length of the mask, see ``compute_mask_indices``.
          mask_other / min_masks:
            Extra parameters forwarded to the mask computation.
          mask_channel_*:
            Same as the above, but masking feature dimensions (channels).
          loss_only_mask:
            If True, compute the codebook loss only on masked frames.
        """
        super().__init__()

        self.encoder_embed = encoder_embed
        self.encoder = encoder
        self.encoder_dim = encoder_dim

        self.distillation_layer = distillation_layer

        self.num_codebooks = num_codebooks
        self.teacher_frame_ratio = teacher_frame_ratio
        self.interpolate_teacher = interpolate_teacher
        self.distillation_delta = distillation_delta

        if num_codebooks > 0:
            # Imported lazily so inference-only configs don't need this module.
            from .spear_modules import JointCodebookLoss
            self.codebook_loss_net = JointCodebookLoss(
                input_dim=encoder_dim,
                # One group of codebooks per teacher frame covered by a
                # single student frame.
                num_codebooks=num_codebooks * self.teacher_frame_ratio,
                reduction="none",
            )
        else:
            self.codebook_loss_net = None

        assert mask_mode in ["w2v2", "block"], f"Unseen mask mode: {mask_mode}"
        self.mask_mode = mask_mode

        # Learned embedding written into masked fbank frames.
        self.mask_emb = nn.Parameter(torch.FloatTensor(n_mels).normal_())
        self.mask_prob = mask_prob
        self.mask_length = mask_length
        self.mask_selection = mask_selection
        self.mask_other = mask_other
        self.min_masks = min_masks

        self.mask_channel_prob = mask_channel_prob
        self.mask_channel_length = mask_channel_length
        self.mask_channel_selection = mask_channel_selection
        self.mask_channel_other = mask_channel_other

        self.loss_only_mask = loss_only_mask

    def forward_encoder(
        self, x: torch.Tensor, x_lens: torch.Tensor, return_middle_out: bool = False, return_dict: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute encoder outputs.

        Args:
          x:
            A 3-D tensor of shape (N, T, C).
          x_lens:
            A 1-D tensor of shape (N,). It contains the number of frames in `x`
            before padding.
          return_middle_out:
            Whether to also collect intermediate-layer outputs.
          return_dict:
            If True, return a dict instead of a tuple.

        Returns:
          encoder_out:
            Encoder output, of shape (N, T, C).
          encoder_out_lens:
            Encoder output lengths, of shape (N,).
          middle_out:
            A list of intermediate-layer outputs, each (N, T, C).
        """
        x, x_lens = self.encoder_embed(x, x_lens)

        src_key_padding_mask = make_pad_mask(x_lens)
        # Zipformer expects time-major input: (T, N, C).
        x = x.permute(1, 0, 2)

        encoder_out, encoder_out_lens, middle_out = self.encoder(x, x_lens, src_key_padding_mask, return_middle_out=True)
        # Convert intermediate layers back to batch-major (N, T, C).
        middle_out = [feat.permute(1, 0, 2) for feat in middle_out]

        encoder_out = encoder_out.permute(1, 0, 2)
        assert torch.all(encoder_out_lens > 0), (x_lens, encoder_out_lens)

        if not return_dict:
            return encoder_out, encoder_out_lens, middle_out
        else:
            outputs = {
                "encoder_out": encoder_out,
                "encoder_out_lens": encoder_out_lens,
                "hidden_states": middle_out,
            }
            return outputs

    def forward(
        self,
        x: torch.Tensor,
        x_lens: torch.Tensor,
        codebook_indexes: torch.Tensor = None,
        mask: bool = True,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Args:
          x:
            A 3-D tensor of shape (N, T, C).
          x_lens:
            A 1-D tensor of shape (N,). It contains the number of frames in `x`
            before padding.
          codebook_indexes:
            Codebook indexes of teacher embeddings.
          mask:
            If we perform w2v2 style of masking over the fbank frames.

        Returns:
          Return the codebook loss.
        """
        assert x.ndim == 3, x.shape
        assert x_lens.ndim == 1, x_lens.shape
        assert codebook_indexes is not None

        if self.training and mask:
            padding_mask = make_pad_mask(x_lens)
            # Clone so the caller's fbank tensor is not modified in place.
            x, mask_indices = self.apply_mask(
                x.clone(),
                padding_mask=padding_mask
            )
        else:
            mask_indices = None

        encoder_out, encoder_out_lens, _ = self.forward_encoder(x, x_lens)

        if codebook_indexes is not None and self.codebook_loss_net is not None:
            codebook_loss = self.forward_codebook_loss(
                encoder_out, encoder_out_lens, codebook_indexes, reduction="none"
            )
            if self.loss_only_mask and mask_indices is not None:
                # Downsample the frame-level mask to the encoder frame rate
                # (the frontend subsamples time by a factor of 4) and keep
                # only the loss on (mostly) masked encoder frames.
                mask_indices = nn.functional.avg_pool1d(mask_indices, 4) >= 0.5
                assert mask_indices.size(1) >= codebook_loss.size(1)
                mask_indices = mask_indices[:, :codebook_loss.size(1)].float()
                codebook_loss = codebook_loss * mask_indices
            codebook_loss = codebook_loss.sum(dim=1)
        else:
            codebook_loss = None

        return codebook_loss

    def forward_codebook_loss(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        codebook_indexes: torch.Tensor,
        reduction: str = "sum",
    ):
        """Compute the MVQ codebook prediction loss against teacher indexes.

        Args:
          encoder_out: student features, (N, T, C).
          encoder_out_lens: valid lengths of `encoder_out`, (N,).
          codebook_indexes: teacher codebook indexes, (N, T_teacher, num_codebooks).
          reduction: "sum" -> per-utterance scalar; "none" -> per-frame loss.
        """
        # Align the teacher frame rate with the student frame rate.
        if self.interpolate_teacher:
            codebook_indexes = self.interpolate_codebook_indexes(
                encoder_out, codebook_indexes
            )
        else:
            if codebook_indexes.shape[1] != encoder_out.shape[1]:
                codebook_indexes = self.concat_successive_codebook_indexes(
                    encoder_out, codebook_indexes, ratio=self.teacher_frame_ratio
                )

        if self.distillation_delta > 0:
            # Shift the student output forward so a streaming model predicts
            # teacher frames that are `delta` frames in the past.
            codebook_indexes = codebook_indexes[:, :-self.distillation_delta, :]
            encoder_out = encoder_out[:, self.distillation_delta:, :]
            truncated_padding_mask = make_pad_mask(encoder_out_lens - self.distillation_delta)
            # -100 marks positions ignored by the codebook loss.
            codebook_indexes = codebook_indexes.masked_fill(truncated_padding_mask.unsqueeze(-1), value=-100)

        N, T, _ = encoder_out.shape
        codebook_loss = self.codebook_loss_net(encoder_out.float(), codebook_indexes)
        codebook_loss = codebook_loss.reshape(N, T, -1)
        num_cb = codebook_loss.size(-1)

        if reduction == "sum":
            codebook_loss = codebook_loss.sum(dim=(1, 2)) / num_cb
        elif reduction == "none":
            codebook_loss = codebook_loss.sum(dim=2) / num_cb
        else:
            raise NotImplementedError()

        return codebook_loss

    def apply_mask(
        self,
        x: torch.Tensor,
        padding_mask: torch.Tensor = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply mask according to the mask_mode, return the masked features and the masked positions

        Args:
            x (torch.Tensor): The input fbank features
            padding_mask (torch.Tensor, optional): The padding mask

        Returns:
            The masked fbank feature and the masked_indices, with masked positions as 1
        """
        if self.mask_mode == "w2v2":
            x, masked_indices = self.apply_mask_w2v2(x, padding_mask)
        elif self.mask_mode == "block":
            x, masked_indices = self.apply_mask_block(x, padding_mask)
        else:
            raise NotImplementedError()

        # Occasionally log the realised masking ratio for monitoring.
        if random.random() > 0.97:
            logging.info(f"Apply {self.mask_mode} masking. A proportion of {masked_indices.sum()/masked_indices.numel():.2f} frames are masked")
        return x, masked_indices

    def apply_mask_block(
        self,
        x: torch.Tensor,
        padding_mask: torch.Tensor = None
    ):
        """Non-overlapping block masking; see ``compute_mask_indices_block``."""
        B, T, C = x.shape
        assert self.mask_prob > 0.0

        mask_indices = compute_mask_indices_block(
            shape=(B, T),
            padding_mask=padding_mask,
            mask_prob=self.mask_prob,
            mask_length=self.mask_length,
            min_masks=self.min_masks,
        ).to(x.device)

        x = index_put(x, mask_indices.bool(), self.mask_emb)

        return x, mask_indices

    def apply_mask_w2v2(
        self,
        x: torch.Tensor,
        padding_mask: torch.Tensor = None
    ):
        """wav2vec2-style (possibly overlapping) time and channel masking."""
        B, T, C = x.shape

        if self.mask_channel_prob > 0:
            mask_channel_indices = compute_mask_indices(
                (B, C),
                None,
                self.mask_channel_prob,
                self.mask_channel_length,
                self.mask_channel_selection,
                self.mask_channel_other,
                no_overlap=False,
                min_space=1,
                require_same_masks=False,
            )
            mask_channel_indices = (
                torch.from_numpy(mask_channel_indices)
                .to(x.device)
                .unsqueeze(1)
                .expand(-1, T, -1)
            )
            if random.random() > 0.98:
                logging.info(f"A proportion of {mask_channel_indices.sum()/mask_channel_indices.numel():.2f} feature dims are masked")
            # Masked channels are zeroed across all time steps.
            x[mask_channel_indices] = 0

        if self.mask_prob > 0:
            mask_indices = compute_mask_indices(
                (B, T),
                padding_mask,
                self.mask_prob,
                self.mask_length,
                mask_type=self.mask_selection,
                mask_other=self.mask_other,
                # Fixed: use the configured minimum instead of a hard-coded 2
                # (the block-mask path already honours self.min_masks).
                min_masks=self.min_masks,
                no_overlap=False,
                min_space=1,
                require_same_masks=False,
            )
            mask_indices = torch.from_numpy(mask_indices).to(x.device)
            x = index_put(x, mask_indices, self.mask_emb)
            mask_indices = mask_indices.float()
        else:
            mask_indices = None

        return x, mask_indices

    @staticmethod
    def interpolate_codebook_indexes(middle_layer_output, codebook_indexes):
        """Resample teacher indexes along time (nearest interpolation) so they
        match the student's number of frames."""
        t_expected = middle_layer_output.shape[1]
        N, T, C = codebook_indexes.shape

        # interpolate operates on (N, C, T); default mode is 'nearest'.
        codebook_indexes = codebook_indexes.permute(0, 2, 1).float()
        codebook_indexes = torch.nn.functional.interpolate(codebook_indexes, t_expected)
        codebook_indexes = codebook_indexes.permute(0, 2, 1).int()

        assert codebook_indexes.shape[1] == middle_layer_output.shape[1]
        return codebook_indexes

    @staticmethod
    def concat_successive_codebook_indexes(middle_layer_output, codebook_indexes, ratio=2):
        """Stack `ratio` successive teacher frames into one student frame.

        The teacher runs at `ratio` times the student's frame rate, so every
        student frame is paired with `ratio` consecutive teacher frames whose
        codebook indexes are concatenated along the last dimension.
        Teacher sequences are truncated or padded with -100 (the ignored
        index) so that exactly t_expected * ratio frames remain.
        """
        t_expected = middle_layer_output.shape[1]
        N, T, C = codebook_indexes.shape

        if T >= t_expected * ratio:
            codebook_indexes = codebook_indexes[:, : t_expected * ratio, :]
        else:
            # Tolerate a small mismatch from rounding at the boundaries.
            assert t_expected * ratio - T <= 5, (T, t_expected, ratio)
            diff = t_expected * ratio - T
            codebook_indexes = torch.cat(
                [
                    codebook_indexes,
                    torch.full((N, diff, C), -100).to(codebook_indexes.device).to(codebook_indexes.dtype)
                ],
                dim=1,
            )
        assert codebook_indexes.size(1) == middle_layer_output.size(1) * ratio

        codebook_indexes = codebook_indexes.reshape(N, t_expected, C * ratio)
        assert middle_layer_output.shape[1] == codebook_indexes.shape[1]
        return codebook_indexes
|
|
|
|
|
def index_put(tensor, indices, value):
    """Write ``value`` into ``tensor`` at the positions selected by ``indices``.

    The assignment happens in place (``value`` broadcasts across the selected
    rows); the same tensor object is returned for convenience.
    """
    tensor[indices] = value
    return tensor
|
|
|
|
|
def compute_mask_indices_block(
    shape,
    padding_mask,
    mask_prob: float = 0.5,
    mask_length: int = 10,
    min_masks: int = 2,
):
    """Compute non-overlapping block masks.

    Each sequence is divided into consecutive segments of ``mask_length``
    frames; every segment is independently masked with probability
    ``mask_prob``. Sampling is repeated until at least ``min_masks``
    segments are selected.

    Args:
        shape: (B, T) — batch size and number of frames.
        padding_mask: optional (B, T) bool tensor; padded frames reduce the
            number of maskable segments and are never masked.
        mask_prob: probability of masking each segment.
        mask_length: segment length in frames.
        min_masks: minimum number of masked segments per sequence.

    Returns:
        A (B, T) float tensor with 1.0 at masked positions.
    """
    B, T = shape
    mask_indices = []
    for i in range(B):
        # int() guards against a 0-dim tensor from padding_mask.sum(),
        # which torch.rand() would reject.
        if padding_mask is not None:
            num_segments = int((T - padding_mask[i].sum()) // mask_length)
        else:
            num_segments = T // mask_length
        # Cap the requirement so short sequences (fewer segments than
        # min_masks) cannot trigger an endless resampling loop.
        required = min(min_masks, num_segments)
        segment_mask = torch.rand(num_segments) < mask_prob
        while segment_mask.sum() < required:
            segment_mask = torch.rand(num_segments) < mask_prob
        # Expand the per-segment decision to per-frame granularity.
        segment_mask_expanded = segment_mask.unsqueeze(-1).expand(num_segments, mask_length)
        segment_mask_expanded = segment_mask_expanded.reshape(-1).float()
        if segment_mask_expanded.size(0) < T:
            # Trailing frames (padding or a partial segment) are never masked.
            pad = T - segment_mask_expanded.size(0)
            segment_mask_expanded = torch.cat([segment_mask_expanded, torch.zeros(pad)])
        mask_indices.append(segment_mask_expanded)

    mask_indices = torch.stack(mask_indices)
    return mask_indices
|
|
|
|
|
def compute_mask_indices(
    shape: Tuple[int, int],
    padding_mask: Optional[torch.Tensor],
    mask_prob: float,
    mask_length: int,
    mask_type: str = "static",
    mask_other: float = 0.0,
    min_masks: int = 0,
    no_overlap: bool = False,
    min_space: int = 0,
    require_same_masks: bool = True,
    mask_dropout: float = 0.0,
    add_masks: bool = False,
    seed: Optional[int] = None,
    epoch: Optional[int] = None,
    indices: Optional[torch.Tensor] = None,
    idc_select_ver: int = 1,
    num_mask_ver: int = 2,
) -> np.ndarray:
    """
    Computes random mask spans for a given shape (fairseq-style).

    Args:
        shape: the shape for which to compute masks.
            should be of size 2 where first element is batch size and 2nd is timesteps
        padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements
        mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by
            number of timesteps divided by length of mask span to mask approximately this percentage of all elements.
            however due to overlaps, the actual number will be smaller (unless no_overlap is True)
        mask_type: how to compute mask lengths
            static = fixed size
            uniform = sample from uniform distribution [mask_other, mask_length*2]
            normal = sample from normal distribution with mean mask_length and stdev mask_other. mask is min 1 element
            poisson = sample from possion distribution with lambda = mask length
        min_masks: minimum number of masked spans
        no_overlap: if true, will switch to an alternative recursive algorithm that prevents spans from overlapping
        min_space: only used if no_overlap is True, this is how many elements to keep unmasked between spans
        require_same_masks: if true, will randomly drop out masks until same amount of masks remains in each sample
        mask_dropout: randomly dropout this percentage of masks in each example
        seed/epoch/indices: if all given, per-utterance deterministic seeding.

    Returns:
        A (bsz, all_sz) boolean numpy array; True marks masked positions.
    """
    bsz, all_sz = shape
    mask = np.full((bsz, all_sz), False)

    if num_mask_ver == 1:
        all_num_mask = int(
            # add a random number for probabilistic rounding
            mask_prob * all_sz / float(mask_length)
            + np.random.rand()
        )
        all_num_mask = max(min_masks, all_num_mask)

    mask_idcs = []
    for i in range(bsz):
        if seed is not None and epoch is not None and indices is not None:
            seed_i = int(hash((seed, epoch, indices[i].item())) % 1e6)
        else:
            seed_i = None

        rng = np.random.default_rng(seed_i)

        if padding_mask is not None:
            sz = all_sz - padding_mask[i].long().sum().item()
            assert sz >= 0, sz
        else:
            sz = all_sz

        if num_mask_ver == 1:
            if padding_mask is not None:
                num_mask = int(
                    # add a random number for probabilistic rounding
                    mask_prob * sz / float(mask_length)
                    + np.random.rand()
                )
                num_mask = max(min_masks, num_mask)
            else:
                num_mask = all_num_mask
        elif num_mask_ver == 2:
            num_mask = int(
                # add a random number for probabilistic rounding
                mask_prob * sz / float(mask_length)
                + rng.random()
            )
            num_mask = max(min_masks, num_mask)
            # Never request more spans than can physically fit.
            hard_max = sz // mask_length
            num_mask = min(hard_max, num_mask)
        else:
            raise ValueError()

        if mask_type == "static":
            lengths = np.full(num_mask, mask_length)
        elif mask_type == "uniform":
            # Fixed: np.random.Generator has no `randint`; the equivalent
            # API (high exclusive) is `integers`.
            lengths = rng.integers(mask_other, mask_length * 2 + 1, size=num_mask)
        elif mask_type == "normal":
            lengths = rng.normal(mask_length, mask_other, size=num_mask)
            lengths = [max(1, int(round(x))) for x in lengths]
        elif mask_type == "poisson":
            lengths = rng.poisson(mask_length, size=num_mask)
            lengths = [int(round(x)) for x in lengths]
        else:
            raise Exception("unknown mask selection " + mask_type)

        if sum(lengths) == 0:
            if mask_type == "static":
                raise ValueError("this should never happens")
            else:
                lengths = [min(mask_length, sz - 1)]

        if no_overlap:
            mask_idc = []

            def arrange(s, e, length, keep_length):
                # Fixed: Generator API — `integers`, not `randint`.
                span_start = rng.integers(s, e - length)
                mask_idc.extend(span_start + i for i in range(length))

                new_parts = []
                if span_start - s - min_space >= keep_length:
                    new_parts.append((s, span_start - min_space + 1))
                if e - span_start - length - min_space > keep_length:
                    new_parts.append((span_start + length + min_space, e))
                return new_parts

            parts = [(0, sz)]
            min_length = min(lengths)
            for length in sorted(lengths, reverse=True):
                # Fixed: np.int was removed in NumPy 1.24; plain int is the
                # documented replacement.
                lens = np.fromiter(
                    (e - s if e - s >= length + min_space else 0 for s, e in parts),
                    int,
                )
                l_sum = np.sum(lens)
                if l_sum == 0:
                    break
                probs = lens / np.sum(lens)
                c = rng.choice(len(parts), p=probs)
                s, e = parts.pop(c)
                parts.extend(arrange(s, e, length, min_length))
            mask_idc = np.asarray(mask_idc)
        else:
            if idc_select_ver == 1:
                min_len = min(lengths)
                if sz - min_len <= num_mask:
                    min_len = sz - num_mask - 1
                mask_idc = rng.choice(sz - min_len, num_mask, replace=False)
            elif idc_select_ver == 2:
                mask_idc = rng.choice(sz, num_mask, replace=False)
            else:
                raise ValueError()

            # Expand each chosen start index into a full span.
            mask_idc = np.asarray(
                [
                    mask_idc[j] + offset
                    for j in range(len(mask_idc))
                    for offset in range(lengths[j])
                ]
            )

        mask_idc = np.unique(mask_idc[mask_idc < sz])
        if len(mask_idc) >= sz:
            # Fixed: the original message rendered the literal text
            # "mask_idc[mask_idc]" instead of the actual indexes.
            raise ValueError(
                (
                    f"the entire sequence is masked. "
                    f"sz={sz}; mask_idc={mask_idc}; "
                    f"index={indices[i] if indices is not None else None}"
                )
            )
        mask_idcs.append(mask_idc)

    target_len = None
    if require_same_masks:
        if add_masks:
            target_len = max([len(m) for m in mask_idcs])
        else:
            target_len = min([len(m) for m in mask_idcs])

    for i, mask_idc in enumerate(mask_idcs):
        if target_len is not None and len(mask_idc) > target_len:
            mask_idc = rng.choice(mask_idc, target_len, replace=False)

        mask[i, mask_idc] = True

        if target_len is not None and len(mask_idc) < target_len:
            unmasked = np.flatnonzero(~mask[i])
            to_mask = rng.choice(unmasked, target_len - len(mask_idc), replace=False)
            mask[i, to_mask] = True

        if mask_dropout > 0:
            masked = np.flatnonzero(mask[i])
            num_holes = np.rint(len(masked) * mask_dropout).astype(int)
            to_drop = rng.choice(masked, num_holes, replace=False)
            mask[i, to_drop] = False

    return mask
|
|
|
|
|
def _to_int_tuple(s: str): |
|
|
return tuple(map(int, s.split(","))) |
|
|
|
|
|
def get_encoder_embed(config: SpearConfig) -> nn.Module:
    """Build the Conv2d subsampling frontend.

    Maps (N, T, num_mel_bins) fbank features to (N, T', encoder_dim[0])
    where T' is roughly T // 2 (see Conv2dSubsampling).
    """
    return Conv2dSubsampling(
        in_channels=config.num_mel_bins,
        out_channels=_to_int_tuple(config.encoder_dim)[0],
    )
|
|
|
|
|
def get_encoder_model(config: SpearConfig) -> nn.Module:
    """Instantiate the Zipformer2 encoder stack described by ``config``."""
    # Per-head attention dimensions are fixed and broadcast across all stacks.
    query_head_dim = _to_int_tuple("32")
    pos_head_dim = _to_int_tuple("4")
    value_head_dim = _to_int_tuple("12")

    return Zipformer2(
        output_downsampling_factor=config.output_downsampling_factor,
        downsampling_factor=_to_int_tuple(config.downsampling_factor),
        num_encoder_layers=_to_int_tuple(config.num_encoder_layers),
        encoder_dim=_to_int_tuple(config.encoder_dim),
        encoder_unmasked_dim=_to_int_tuple(config.encoder_unmasked_dim),
        query_head_dim=query_head_dim,
        pos_head_dim=pos_head_dim,
        value_head_dim=value_head_dim,
        pos_dim=config.pos_dim,
        num_heads=_to_int_tuple(config.num_heads),
        feedforward_dim=_to_int_tuple(config.feedforward_dim),
        cnn_module_kernel=_to_int_tuple(config.cnn_module_kernel),
        warmup_batches=4000.0,
        causal=config.causal,
        chunk_size=config.chunk_size,
        left_context_frames=config.left_context_frames,
    )
|
|
|
|
|
|
|
|
def _test_w2v2_channel_mask():
    """Print the realised channel-masking ratio for a few configurations."""
    x = torch.ones(100, 1000, 128)
    B, T, C = x.shape

    for mask_channel_prob, mask_channel_length in [(0.25, 15), (0.25, 20), (0.5, 15)]:
        ratios = []
        for _ in range(20):
            channel_mask = compute_mask_indices(
                (B, C),
                None,
                mask_channel_prob,
                mask_channel_length,
                "static",
                0.0,
                no_overlap=False,
                min_space=1,
                require_same_masks=False,
            )
            channel_mask = (
                torch.from_numpy(channel_mask)
                .to(x.device)
                .unsqueeze(1)
                .expand(-1, T, -1)
            )
            ratios.append(channel_mask.sum() / channel_mask.numel())
        print(f"Current config: mask_channel_prob = {mask_channel_prob}, mask_channel_length = {mask_channel_length}")
        avg_ratio = sum(ratios) / len(ratios)
        print(f"Averaged masking ratio: {avg_ratio}")
|
|
|
|
|
def _test_w2v2_mask():
    """Print realised w2v2-style time-masking ratios for several configs."""
    x = torch.ones(100, 1000, 128)
    B, T, C = x.shape

    mask_prob = 0.65
    mask_length = 10

    # A full (prob, length) sweep; superseded by the short list right below.
    configs = []
    for i in range(6):
        p = 0.05 + (i + 1) * 0.1
        for l in [10, 20, 30, 40]:
            configs.append((p, l))
    configs = [(0.65, 10), (0.02, 40), (0.05, 40), (0.1, 40)]

    for mask_prob, mask_length in configs:
        ratios = []
        for _ in range(20):
            time_mask = compute_mask_indices(
                (B, T),
                None,
                mask_prob,
                mask_length,
                mask_type="static",
                mask_other=0.0,
                min_masks=2,
                no_overlap=False,
                min_space=1,
                require_same_masks=False,
            )
            time_mask = torch.from_numpy(time_mask)
            ratios.append(time_mask.sum() / time_mask.numel())
        print(f"Current config: mask_prob = {mask_prob}, mask_length = {mask_length}")
        avg_ratio = sum(ratios) / len(ratios)
        print(f"Averaged masking ratio: {avg_ratio}")
|
|
|
|
|
def _test_custom_mask():
    """Print realised block-masking ratios, jittering mask_length each trial."""
    x = torch.ones(100, 1000, 128)
    B, T, C = x.shape

    configs = [(0.5, 20), (0.2, 20), (0.3, 20), (0.4, 20), (0.5, 20)]
    for mask_prob, mask_length in configs:
        ratios = []
        for _ in range(20):
            # Jitter the block length by up to +-10 frames in steps of 2.
            candidates = [mask_length + k * 2 for k in range(-5, 6)]
            mask_length = random.sample(candidates, 1)[0]
            assert mask_length > 0, f"Sampled mask_length smaller than 0, {mask_length}"

            block_mask = compute_mask_indices_block(
                shape=(B, T),
                padding_mask=None,
                mask_prob=mask_prob,
                mask_length=mask_length,
                min_masks=2,
            )
            ratios.append(block_mask.sum() / block_mask.numel())
        print(f"Current config: mask_prob = {mask_prob}, mask_length = {mask_length}")
        avg_ratio = sum(ratios) / len(ratios)
        print(f"Averaged masking ratio: {avg_ratio}")
|
|
|
|
|
|
|
|
# Smoke-test the masking utilities (prints realised masking ratios) when the
# module is executed directly as a script.
if __name__=="__main__":
    _test_w2v2_channel_mask()
    _test_w2v2_mask()
    _test_custom_mask()