wh1tet3a committed
Commit 7afafba · 1 Parent(s): e39cd34

add spectra_0

Files changed (2)
  1. model.py +261 -0
  2. model.safetensors +3 -0
model.py ADDED
@@ -0,0 +1,261 @@
+ import math
+ import torch
+ import torch.nn as nn
+ from transformers import Wav2Vec2Model
+ from huggingface_hub import PyTorchModelHubMixin
+
+
+ class SEModule(nn.Module):
+     """Squeeze-and-Excitation block: pools over time and rescales each channel by a learned gate."""
+
+     def __init__(self, channels, bottleneck=128):
+         super(SEModule, self).__init__()
+         self.se = nn.Sequential(
+             nn.AdaptiveAvgPool1d(1),
+             nn.Conv1d(channels, bottleneck, kernel_size=1, padding=0),
+             nn.ReLU(),
+             # nn.BatchNorm1d(bottleneck),  # I removed this layer
+             nn.Conv1d(bottleneck, channels, kernel_size=1, padding=0),
+             nn.Sigmoid(),
+         )
+
+     def forward(self, input):
+         x = self.se(input)
+         return input * x
+
+
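A quick shape check for the SE block above (a hedged sketch, not part of the committed file): the block pools over time, squeezes to `bottleneck` channels, and emits per-channel gates in (0, 1), so the input shape is preserved.

import torch

se = SEModule(channels=512)
x = torch.randn(2, 512, 200)   # (batch, channels, frames)
y = se(x)                      # gates in (0, 1) rescale each channel
print(y.shape)                 # torch.Size([2, 512, 200])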
+ class Bottle2neck(nn.Module):
+     """Res2Net-style bottleneck with dilated multi-scale convs and an SE gate, as in ECAPA-TDNN.
+
+     Note: kernel_size and dilation must be supplied by the caller; they have no usable defaults.
+     """
+
+     def __init__(self, inplanes, planes, kernel_size=None, dilation=None, scale=8):
+         super(Bottle2neck, self).__init__()
+         width = int(math.floor(planes / scale))
+         self.conv1 = nn.Conv1d(inplanes, width * scale, kernel_size=1)
+         self.bn1 = nn.BatchNorm1d(width * scale)
+         self.nums = scale - 1
+         convs = []
+         bns = []
+         num_pad = math.floor(kernel_size / 2) * dilation
+         for i in range(self.nums):
+             convs.append(nn.Conv1d(width, width, kernel_size=kernel_size, dilation=dilation, padding=num_pad))
+             bns.append(nn.BatchNorm1d(width))
+         self.convs = nn.ModuleList(convs)
+         self.bns = nn.ModuleList(bns)
+         self.conv3 = nn.Conv1d(width * scale, planes, kernel_size=1)
+         self.bn3 = nn.BatchNorm1d(planes)
+         self.relu = nn.ReLU()
+         self.width = width
+         self.se = SEModule(planes)
+
+     def forward(self, x):
+         residual = x
+         out = self.conv1(x)
+         out = self.relu(out)
+         out = self.bn1(out)
+
+         # Res2Net: split channels into `scale` groups and process them hierarchically.
+         spx = torch.split(out, self.width, 1)
+         for i in range(self.nums):
+             if i == 0:
+                 sp = spx[i]
+             else:
+                 sp = sp + spx[i]
+             sp = self.convs[i](sp)
+             sp = self.relu(sp)
+             sp = self.bns[i](sp)
+             if i == 0:
+                 out = sp
+             else:
+                 out = torch.cat((out, sp), 1)
+         # The last group passes through unchanged.
+         out = torch.cat((out, spx[self.nums]), 1)
+
+         out = self.conv3(out)
+         out = self.relu(out)
+         out = self.bn3(out)
+
+         out = self.se(out)
+         out += residual
+         return out
+
+
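For intuition, a hedged sketch of how the block above behaves (not part of the committed file): with planes=128 and scale=8, the channels split into 8 groups of width 16; seven groups pass through dilated convs, each also receiving the previous group's output, the eighth passes through untouched, and conv3 recombines everything before the SE gate and residual add.

import torch

block = Bottle2neck(inplanes=128, planes=128, kernel_size=3, dilation=2, scale=8)
x = torch.randn(2, 128, 200)   # (batch, channels, frames)
print(block(x).shape)          # torch.Size([2, 128, 200]); padding and the residual keep the shape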
+ class ECAPA_TDNN(nn.Module):
+     """ECAPA-TDNN backbone with attentive statistics pooling and a two-class head."""
+
+     def __init__(self, C):
+         super(ECAPA_TDNN, self).__init__()
+         self.conv1 = nn.Conv1d(128, C, kernel_size=5, stride=1, padding=2)
+         self.relu = nn.ReLU()
+         self.bn1 = nn.BatchNorm1d(C)
+         self.layer1 = Bottle2neck(C, C, kernel_size=3, dilation=2, scale=8)
+         self.layer2 = Bottle2neck(C, C, kernel_size=3, dilation=3, scale=8)
+         self.layer3 = Bottle2neck(C, C, kernel_size=3, dilation=4, scale=8)
+         self.layer4 = Bottle2neck(C, C, kernel_size=3, dilation=5, scale=8)
+         # I fixed the output shape of the MFA layer so that it is close to the setting in the ECAPA paper.
+         self.layer5 = nn.Conv1d(4 * C, 1536, kernel_size=1)
+         self.attention = nn.Sequential(
+             nn.Conv1d(4608, 256, kernel_size=1),
+             nn.ReLU(),
+             nn.BatchNorm1d(256),
+             nn.Tanh(),  # I added this layer
+             nn.Conv1d(256, 1536, kernel_size=1),
+             nn.Softmax(dim=2),
+         )
+         self.bn5 = nn.BatchNorm1d(3072)
+         self.fc6 = nn.Linear(3072, 2)
+
+     def forward(self, x):
+         x = x.transpose(1, 2)
+         x = self.conv1(x)
+         x = self.relu(x)
+         x = self.bn1(x)
+
+         x1 = self.layer1(x)
+         x2 = self.layer2(x + x1)
+         x3 = self.layer3(x + x1 + x2)
+         x4 = self.layer4(x + x1 + x2 + x3)
+
+         x = self.layer5(torch.cat((x1, x2, x3, x4), dim=1))
+         x = self.relu(x)
+
+         t = x.size()[-1]
+
+         # Concatenate frame features with their global mean and std (4608 = 3 * 1536 channels).
+         global_x = torch.cat(
+             (
+                 x,
+                 torch.mean(x, dim=2, keepdim=True).repeat(1, 1, t),
+                 torch.sqrt(torch.var(x, dim=2, keepdim=True).clamp(min=1e-4)).repeat(1, 1, t),
+             ),
+             dim=1,
+         )
+
+         w = self.attention(global_x)
+
+         # Attentive statistics pooling: attention-weighted mean and std over time.
+         mu = torch.sum(x * w, dim=2)
+         sg = torch.sqrt((torch.sum((x**2) * w, dim=2) - mu**2).clamp(min=1e-4))
+
+         x = torch.cat((mu, sg), 1)
+         x = self.bn5(x)
+         x = self.fc6(x)
+
+         return x
+
+
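The pooling in the forward pass above is attentive statistics pooling: the Softmax(dim=2) makes each channel's attention weights sum to one over time, so `mu` is a weighted mean and `sg` a weighted standard deviation, concatenated to 3072 dims (2 x 1536). A hedged end-to-end shape check (not part of the committed file):

import torch

net = ECAPA_TDNN(C=128).eval()
feats = torch.randn(2, 200, 128)   # (batch, frames, 128), as the bridge below produces
with torch.no_grad():
    logits = net(feats)
print(logits.shape)                # torch.Size([2, 2]) -- two-class logits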
+ class Wav2Vec2Encoder(nn.Module):
+     """SSL encoder based on Hugging Face's Wav2Vec2 model."""
+
+     def __init__(self,
+                  model_name_or_path: str = "facebook/wav2vec2-base-960h",
+                  output_attentions: bool = False,
+                  output_hidden_states: bool = False,
+                  normalize_waveform: bool = False):
+         """Initialize the Wav2Vec2 encoder.
+
+         Args:
+             model_name_or_path: Hugging Face model name or path to a local model.
+             output_attentions: Whether to output attentions.
+             output_hidden_states: Whether to output hidden states.
+             normalize_waveform: Whether to normalize the waveform input.
+         """
+         super().__init__()
+
+         self.model_name_or_path = model_name_or_path
+         self.output_attentions = output_attentions
+         self.output_hidden_states = output_hidden_states
+         self.normalize_waveform = normalize_waveform
+
+         # Load the Wav2Vec2 model; SpecAugment masking is disabled for feature extraction.
+         self.model = Wav2Vec2Model.from_pretrained(
+             model_name_or_path,
+             gradient_checkpointing=False)
+         self.model.config.apply_spec_augment = False
+         self.model.masked_spec_embed = None
+
+     def forward(self, x):
+         """Forward pass through the Wav2Vec2 encoder.
+
+         Args:
+             x: Input tensor of shape (batch_size, sequence_length, channels)
+                or (batch_size, sequence_length).
+
+         Returns:
+             Extracted features of shape (batch_size, frames, hidden_size);
+             hidden_size is 1024 for the XLS-R 300M checkpoint used by Spectra0Model.
+         """
+         # Handle shape: convert (batch_size, sequence_length, channels) to (batch_size, sequence_length).
+         if x.ndim == 3:
+             x = x.squeeze(-1)  # Remove a singleton channel dimension if present.
+
+         # Peak-normalize the waveform if requested.
+         if self.normalize_waveform:
+             x = x / (torch.max(torch.abs(x), dim=1, keepdim=True)[0] + 1e-8)
+
+         # Wav2Vec2 forward pass.
+         outputs = self.model(
+             x,
+             output_attentions=self.output_attentions,
+             output_hidden_states=self.output_hidden_states,
+             return_dict=True,
+         )
+
+         # Extract the last hidden state.
+         last_hidden_state = outputs.last_hidden_state
+
+         return last_hidden_state
+
+
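For sizing inputs to this encoder: Wav2Vec2's convolutional front end downsamples 16 kHz audio by a factor of 320 (conv strides 5·2·2·2·2·2·2), so the frame count is close to samples/320, with the exact number depending on the kernel sizes. A rough arithmetic sketch (an assumption about the checkpoint's default feature extractor, not part of the committed file):

sr = 16000                 # Wav2Vec2 checkpoints expect 16 kHz mono audio
samples = 4 * sr           # a 4-second clip = 64000 samples
frames = samples // 320    # ~200 hidden-state frames out of the encoder
print(frames)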
+ class MLPBridge(nn.Module):
+     """MLP that projects SSL encoder features down to the ECAPA-TDNN input dimension."""
+
+     def __init__(self, input_dim: int, output_dim: int, hidden_dim: int = None,
+                  dropout: float = 0.1, activation: nn.Module = nn.ReLU(), n_layers: int = 1):
+         """Initialize the MLP bridge.
+
+         Args:
+             input_dim: The input dimension from the SSL encoder.
+             output_dim: The output dimension for the model.
+             hidden_dim: Hidden dimension size. If None, use the average of the input and output dims.
+             dropout: Dropout probability to apply between layers.
+             activation: Activation module instance to use between layers.
+             n_layers: Number of MLP layers (repeats of Linear+Activation+Dropout blocks).
+         """
+         super().__init__()
+
+         if hidden_dim is None:
+             hidden_dim = (input_dim + output_dim) // 2
+
+         self.input_dim = input_dim
+         self.output_dim = output_dim
+         self.hidden_dim = hidden_dim
+         self.n_layers = n_layers
+
+         assert hasattr(activation, 'forward') and callable(getattr(activation, 'forward', None)), \
+             "Activation must be a module instance with a callable forward() method."
+         act_fn = activation  # The same (stateless) activation instance is shared across layers.
+
+         layers = []
+         for i in range(n_layers):
+             in_dim = input_dim if i == 0 else hidden_dim
+             out_dim = hidden_dim
+             layers.append(nn.Linear(in_dim, out_dim))
+             layers.append(act_fn)
+             layers.append(nn.Dropout(dropout) if dropout > 0 else nn.Identity())
+         # Final output layer
+         layers.append(nn.Linear(hidden_dim, output_dim))
+         layers.append(nn.Dropout(dropout) if dropout > 0 else nn.Identity())
+
+         self.mlp = nn.Sequential(*layers)
+
+     def forward(self, x):
+         """Forward pass through the bridge.
+
+         Args:
+             x: The input tensor from the SSL encoder.
+
+         Returns:
+             The transformed tensor.
+         """
+         return self.mlp(x)
+
+
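With the arguments used by Spectra0Model below (input_dim=1024, output_dim=128, hidden_dim=128, n_layers=1), the bridge unrolls to Linear(1024, 128) -> SELU -> Dropout(0.1) -> Linear(128, 128) -> Dropout(0.1); the trailing dropout is inactive in eval mode. A hedged illustration (not part of the committed file):

import torch.nn as nn

bridge = MLPBridge(1024, 128, hidden_dim=128, activation=nn.SELU())
print(bridge.mlp)   # Sequential of Linear/SELU/Dropout/Linear/Dropout as listed above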
+ class Spectra0Model(nn.Module, PyTorchModelHubMixin):
+     """Spectra_0: Wav2Vec2 XLS-R features -> MLP bridge -> ECAPA-TDNN two-class head."""
+
+     def __init__(self, **kwargs):
+         super().__init__()
+         self.ssl_encoder = Wav2Vec2Encoder("facebook/wav2vec2-xls-r-300m")
+         self.bridge = MLPBridge(1024, 128, hidden_dim=128, activation=nn.SELU())
+         self.ecapa_tdnn = ECAPA_TDNN(128)
+
+     def forward(self, x):
+         x = self.ssl_encoder(x)
+         x = self.bridge(x)
+         x = self.ecapa_tdnn(x)
+         return x
+
+     @torch.inference_mode()
+     def classify(self, x, threshold: float = 0.399):
+         # Threshold the raw class-1 logit; .item() assumes a batch of size one.
+         x = self.forward(x)[:, 1]
+         x = (x > threshold).float()
+         return x.item()
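A minimal end-to-end usage sketch for the committed model (hedged: the hub repo id below is an assumption from the commit context, and Wav2Vec2 expects 16 kHz mono waveforms):

import torch
from model import Spectra0Model

# PyTorchModelHubMixin provides from_pretrained(), which also pulls model.safetensors.
model = Spectra0Model.from_pretrained("wh1tet3a/spectra_0")  # repo id is an assumption
model.eval()

waveform = torch.randn(1, 64000)   # dummy 4-second clip at 16 kHz
logits = model(waveform)           # shape (1, 2): raw two-class logits
label = model.classify(waveform)   # 1.0 if the class-1 logit > 0.399, else 0.0
print(logits, label)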
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:830d05e5ff3fe6860858fdfee2bfdf61e0287bbf26892731e846ff7bbef5546b
+ size 1273453560