bweng committed on
Commit 1ac1652 · verified · 1 Parent(s): daaf844

Upload 3 files
Samples for how to use the models

Files changed (4)
  1. .gitattributes +2 -0
  2. TS3003b_mix_headset.wav +3 -0
  3. first_10_seconds.wav +3 -0
  4. main.swift +449 -0
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ first_10_seconds.wav filter=lfs diff=lfs merge=lfs -text
+ TS3003b_mix_headset.wav filter=lfs diff=lfs merge=lfs -text
TS3003b_mix_headset.wav ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0c94f3a09ab747caa7714efe8852f5ff37d36cf272b75709344991df1aa266ca
+ size 70729772
first_10_seconds.wav ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:497c9204618be272b312cc04d2f21ad7a2dade87581e77efcabede1cbe11582b
+ size 320044
main.swift ADDED
@@ -0,0 +1,449 @@
+ import Accelerate
+ import AVFoundation
+ import CoreML
+ import Foundation
+
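+ // Sample driver for speaker diarization with two Core ML models:
+ //   - pyannote_segmentation.mlmodelc: per-frame speaker activity over 10 s windows
+ //   - wespeaker.mlmodelc: per-speaker embedding vectors (3 x 256 here)
+ // Audio is resampled to 16 kHz mono before inference. The feature names
+ // ("audio", "segments", "waveform", "mask", "embedding") and shapes assumed
+ // below follow this sample's conversions; other exports of the same models
+ // may differ.
+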
+ struct Segment: Hashable {
+     let start: Double
+     let end: Double
+ }
+
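+ // A regular frame grid over the audio: frame i covers the interval
+ // [start + i * step, start + i * step + duration].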
+ struct SlidingWindow {
+     var start: Double
+     var duration: Double
+     var step: Double
+
+     func time(forFrame index: Int) -> Double {
+         return start + Double(index) * step
+     }
+
+     func segment(forFrame index: Int) -> Segment {
+         let s = time(forFrame: index)
+         return Segment(start: s, end: s + duration)
+     }
+ }
+
+ struct SlidingWindowFeature {
+     var data: [[[Float]]] // (1, 589, 3)
+     var slidingWindow: SlidingWindow
+ }
+
+ var speakerDB: [String: [Float]] = [:] // Global speaker database: label -> reference embedding
+ let threshold: Float = 0.7 // Cosine-distance threshold for enrolling a new speaker
+
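+ // Cosine distance: 1 - cos(x, y), so 0 means identical direction and 2 means
+ // opposite. The 1e-6 term guards against division by zero for all-zero vectors.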
+ func cosineDistance(_ x: [Float], _ y: [Float]) -> Float {
+     precondition(x.count == y.count, "Vectors must be same size")
+     let dot = zip(x, y).map(*).reduce(0, +)
+     let normX = sqrt(x.map { $0 * $0 }.reduce(0, +))
+     let normY = sqrt(y.map { $0 * $0 }.reduce(0, +))
+     return 1.0 - (dot / (normX * normY + 1e-6))
+ }
+
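+ // Exponential moving average of a speaker's reference embedding: alpha = 0.9
+ // keeps 90% of the stored reference and mixes in 10% of the new observation.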
+ func updateSpeakerDB(_ speaker: String, _ newEmbedding: [Float], alpha: Float = 0.9) {
+     guard var oldEmbedding = speakerDB[speaker] else { return }
+     for i in 0 ..< oldEmbedding.count {
+         oldEmbedding[i] = alpha * oldEmbedding[i] + (1 - alpha) * newEmbedding[i]
+     }
+     speakerDB[speaker] = oldEmbedding
+ }
+
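+ // Online speaker assignment: find the nearest enrolled speaker by cosine
+ // distance; if even the best match is farther than `threshold`, enroll the
+ // embedding as a new speaker, otherwise update the matched reference.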
+ func assignSpeaker(embedding: [Float], threshold: Float = 0.7) -> String {
+     if speakerDB.isEmpty {
+         let speaker = "Speaker 1"
+         speakerDB[speaker] = embedding
+         return speaker
+     }
+
+     var minDistance: Float = Float.greatestFiniteMagnitude
+     var identifiedSpeaker: String? = nil
+
+     for (speaker, refEmbedding) in speakerDB {
+         let distance = cosineDistance(embedding, refEmbedding)
+         if distance < minDistance {
+             minDistance = distance
+             identifiedSpeaker = speaker
+         }
+     }
+
+     if let bestSpeaker = identifiedSpeaker {
+         if minDistance > threshold {
+             let newSpeaker = "Speaker \(speakerDB.count + 1)"
+             speakerDB[newSpeaker] = embedding
+             return newSpeaker
+         } else {
+             updateSpeakerDB(bestSpeaker, embedding)
+             return bestSpeaker
+         }
+     }
+
+     return "Unknown"
+ }
+
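+ // Converts per-frame speaker activations into labelled time segments:
+ // argmax per frame picks the dominant speaker, runs of the same speaker are
+ // merged, and local speaker indices (0..2 within this chunk) are mapped to
+ // global labels via `speakerMapping`.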
+ func getAnnotation(annotation: inout [Segment: String],
+                    speakerMapping: [Int: Int],
+                    binarizedSegments: [[[Float]]],
+                    slidingWindow: SlidingWindow) {
+     let segmentation = binarizedSegments[0] // shape: [589][3]
+     let numFrames = segmentation.count
+     guard numFrames > 0 else { return }
+
+     // Step 1: argmax to get the dominant speaker per frame
+     var frameSpeakers: [Int] = []
+     for frame in segmentation {
+         if let maxIdx = frame.indices.max(by: { frame[$0] < frame[$1] }) {
+             frameSpeakers.append(maxIdx)
+         } else {
+             frameSpeakers.append(0) // fallback
+         }
+     }
+
+     // Step 2: group contiguous same-speaker frames into segments
+     var currentSpeaker = frameSpeakers[0]
+     var startFrame = 0
+
+     for i in 1 ..< numFrames {
+         if frameSpeakers[i] != currentSpeaker {
+             let startTime = slidingWindow.time(forFrame: startFrame)
+             let endTime = slidingWindow.time(forFrame: i)
+
+             let segment = Segment(start: startTime, end: endTime)
+             if let mappedSpeaker = speakerMapping[currentSpeaker] {
+                 annotation[segment] = "Speaker \(mappedSpeaker)"
+             }
+             currentSpeaker = frameSpeakers[i]
+             startFrame = i
+         }
+     }
+
+     // Final segment runs to the end of the last frame's window
+     let finalStart = slidingWindow.time(forFrame: startFrame)
+     let finalEnd = slidingWindow.segment(forFrame: numFrames - 1).end
+     let finalSegment = Segment(start: finalStart, end: finalEnd)
+     if let mappedSpeaker = speakerMapping[currentSpeaker] {
+         annotation[finalSegment] = "Speaker \(mappedSpeaker)"
+     }
+ }
+
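+ // Runs the speaker-embedding model on one 10 s chunk. Frames where two or
+ // more speakers are active are masked out, so each speaker's embedding is
+ // computed only from frames where that speaker talks alone.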
+ func getEmbedding(audioChunk: [Float],
+                   binarizedSegments _: [[[Float]]],
+                   slidingWindowSegments: SlidingWindowFeature,
+                   chunkSize: Int = 10 * 16000,
+                   embeddingModel: MLModel) -> MLMultiArray?
+ {
+     let numFrames = slidingWindowSegments.data[0].count
+     let numSpeakers = slidingWindowSegments.data[0][0].count
+
+     // 1. clean_frames = 1.0 where fewer than two speakers are active
+     var cleanFrames = Array(repeating: Array(repeating: 0.0 as Float, count: 1), count: numFrames)
+
+     for f in 0 ..< numFrames {
+         let frame = slidingWindowSegments.data[0][f]
+         let speakerSum = frame.reduce(0, +)
+         cleanFrames[f][0] = (speakerSum < 2.0) ? 1.0 : 0.0
+     }
+
+     // 2. Zero out every speaker's activity on overlapped frames
+     var cleanSegmentData = Array(
+         repeating: Array(repeating: Array(repeating: 0.0 as Float, count: numSpeakers), count: numFrames),
+         count: 1
+     )
+
+     for f in 0 ..< numFrames {
+         for s in 0 ..< numSpeakers {
+             cleanSegmentData[0][f][s] = slidingWindowSegments.data[0][f][s] * cleanFrames[f][0]
+         }
+     }
+
+     // 3. Replicate the chunk once per speaker -> (3, chunkSize)
+     var audioBatch: [[Float]] = []
+     for _ in 0 ..< 3 {
+         audioBatch.append(audioChunk)
+     }
+
+     // 4. Transpose the mask to (numSpeakers, numFrames), i.e. (3, 589)
+     var cleanMasks: [[Float]] = Array(repeating: Array(repeating: 0.0, count: numFrames), count: numSpeakers)
+
+     for s in 0 ..< numSpeakers {
+         for f in 0 ..< numFrames {
+             cleanMasks[s][f] = cleanSegmentData[0][f][s]
+         }
+     }
+
+     // 5. Prepare MLMultiArray inputs
+     guard let waveformArray = try? MLMultiArray(shape: [3, chunkSize] as [NSNumber], dataType: .float32),
+           let maskArray = try? MLMultiArray(shape: [3, numFrames] as [NSNumber], dataType: .float32)
+     else {
+         print("Failed to allocate MLMultiArray")
+         return nil
+     }
+
+     // Fill waveform
+     for s in 0 ..< 3 {
+         for i in 0 ..< chunkSize {
+             waveformArray[s * chunkSize + i] = NSNumber(value: audioBatch[s][i])
+         }
+     }
+
+     // Fill mask
+     for s in 0 ..< 3 {
+         for f in 0 ..< numFrames {
+             maskArray[s * numFrames + f] = NSNumber(value: cleanMasks[s][f])
+         }
+     }
+
+     // 6. Run the embedding model
+     let inputs: [String: Any] = [
+         "waveform": waveformArray,
+         "mask": maskArray,
+     ]
+
+     guard let output = try? embeddingModel.prediction(from: MLDictionaryFeatureProvider(dictionary: inputs)) else {
+         print("Embedding model prediction failed")
+         return nil
+     }
+
+     return output.featureValue(for: "embedding")?.multiArrayValue
+ }
+
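+ // Decodes an audio file and converts it to 16 kHz mono Float32, the format
+ // both models expect.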
+ func loadAudioSamples(from url: URL, expectedSampleRate: Double = 16000.0) throws -> [Float] {
+     let file = try AVAudioFile(forReading: url)
+     let format = AVAudioFormat(commonFormat: .pcmFormatFloat32,
+                                sampleRate: expectedSampleRate,
+                                channels: 1,
+                                interleaved: false)!
+
+     let converter = AVAudioConverter(from: file.processingFormat, to: format)!
+     let frameCapacity = AVAudioFrameCount(file.length)
+     let buffer = AVAudioPCMBuffer(pcmFormat: file.processingFormat, frameCapacity: frameCapacity)!
+     try file.read(into: buffer)
+
+     // Size the output for the resampled length (input length scaled by the rate ratio)
+     let ratio = expectedSampleRate / file.processingFormat.sampleRate
+     let outputCapacity = AVAudioFrameCount((Double(file.length) * ratio).rounded(.up))
+     let outputBuffer = AVAudioPCMBuffer(pcmFormat: format, frameCapacity: outputCapacity)!
+
+     // Hand the converter the whole file once, then signal end-of-stream so it
+     // cannot pull the same buffer twice
+     var consumed = false
+     let inputBlock: AVAudioConverterInputBlock = { _, outStatus in
+         if consumed {
+             outStatus.pointee = .endOfStream
+             return nil
+         }
+         consumed = true
+         outStatus.pointee = .haveData
+         return buffer
+     }
+
+     let status = converter.convert(to: outputBuffer, error: nil, withInputFrom: inputBlock)
+     guard status != .error else {
+         throw NSError(domain: "Audio", code: -2, userInfo: [NSLocalizedDescriptionKey: "Conversion failed"])
+     }
+
+     guard let floatChannelData = outputBuffer.floatChannelData else {
+         throw NSError(domain: "Audio", code: -1, userInfo: [NSLocalizedDescriptionKey: "Missing float data"])
+     }
+
+     return Array(UnsafeBufferPointer(start: floatChannelData[0], count: Int(outputBuffer.frameLength)))
+ }
+
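+ // Top-level pipeline: split the recording into 10 s chunks (zero-padding the
+ // last one), run segmentation, embed each chunk's speakers, match them against
+ // the global speaker database, and accumulate time-stamped annotations.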
+ func chunkAndRunSegmentation(samples: [Float], chunkSize: Int = 160_000, model: MLModel, embeddingModel: MLModel) throws {
+     let totalSamples = samples.count
+     let numberOfChunks = Int(ceil(Double(totalSamples) / Double(chunkSize)))
+     var annotations: [Segment: String] = [:]
+
+     for i in 0 ..< numberOfChunks {
+         let start = i * chunkSize
+         let end = min((i + 1) * chunkSize, totalSamples)
+         let chunk = Array(samples[start ..< end])
+
+         // If the chunk is shorter than 10 s, pad with zeros
+         var paddedChunk = chunk
+         if chunk.count < chunkSize {
+             paddedChunk += Array(repeating: 0.0, count: chunkSize - chunk.count)
+         }
+
+         let binarizedSegments = try getSegments(audioChunk: paddedChunk, model: model)
+         let frames = SlidingWindow(start: Double(i) * 10.0, duration: 0.0619375, step: 0.016875)
+         let slidingFeature = SlidingWindowFeature(data: binarizedSegments, slidingWindow: frames)
+         if let embeddings = getEmbedding(audioChunk: paddedChunk,
+                                          binarizedSegments: binarizedSegments,
+                                          slidingWindowSegments: slidingFeature,
+                                          embeddingModel: embeddingModel)
+         {
+             print("Embeddings shape: \(embeddings.shape.map { $0.intValue })")
+
+             let shape = embeddings.shape.map { $0.intValue } // [3, 256]
+             let numSpeakers = shape[0]
+             let embeddingDim = shape[1]
+             let strides = embeddings.strides.map { $0.intValue }
+
+             var speakerSums = [Float](repeating: 0.0, count: numSpeakers)
+
+             for s in 0 ..< numSpeakers {
+                 for d in 0 ..< embeddingDim {
+                     let index = s * strides[0] + d * strides[1]
+                     speakerSums[s] += embeddings[index].floatValue
+                 }
+             }
+
+             print("Sum along axis 1 (per speaker): \(speakerSums)")
+
+             // Assign a global speaker label to each embedding
+             var speakerLabels = [String]()
+             for s in 0 ..< numSpeakers {
+                 var embeddingVec = [Float](repeating: 0.0, count: embeddingDim)
+                 for d in 0 ..< embeddingDim {
+                     let index = s * strides[0] + d * strides[1]
+                     embeddingVec[d] = embeddings[index].floatValue
+                 }
+                 let label = assignSpeaker(embedding: embeddingVec)
+                 speakerLabels.append(label)
+             }
+
+             print("Chunk \(i + 1): Assigned Speakers: \(speakerLabels)")
+
+             // Map local speaker indices 0, 1, 2 to the assigned global labels
+             var labelMapping: [Int: Int] = [:]
+             for (idx, label) in speakerLabels.enumerated() {
+                 if let spkNum = Int(label.components(separatedBy: " ").last ?? "") {
+                     labelMapping[idx] = spkNum
+                 }
+             }
+
+             getAnnotation(annotation: &annotations,
+                           speakerMapping: labelMapping,
+                           binarizedSegments: binarizedSegments,
+                           slidingWindow: frames)
+
+             print("Chunk \(i + 1) → Segments shape: \(binarizedSegments[0].count) frames")
+         }
+     }
+
+     // Final result
+     print("\n=== Final Annotations ===")
+     for (segment, speaker) in annotations.sorted(by: { $0.key.start < $1.key.start }) {
+         print("\(speaker): \(segment.start) - \(segment.end)")
+     }
+ }
+
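+ // The segmentation model emits 7 "powerset" classes per frame: silence, each
+ // single speaker, and each speaker pair. Binarization takes the argmax class
+ // per frame and marks its member speakers active, yielding (1, frames, 3).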
+ func powersetConversion(_ segments: [[[Float]]]) -> [[[Float]]] {
+     // Class index -> set of active speakers
+     let powerset: [[Int]] = [
+         [],     // 0: silence
+         [0],    // 1
+         [1],    // 2
+         [2],    // 3
+         [0, 1], // 4
+         [0, 2], // 5
+         [1, 2], // 6
+     ]
+
+     let batchSize = segments.count
+     let numFrames = segments[0].count
+     let numCombos = segments[0][0].count
+     assert(numCombos == powerset.count, "Expected 7 powerset classes")
+
+     let numSpeakers = 3
+     var binarized = Array(
+         repeating: Array(
+             repeating: Array(repeating: 0.0 as Float, count: numSpeakers),
+             count: numFrames
+         ),
+         count: batchSize
+     )
+
+     for b in 0 ..< batchSize {
+         for f in 0 ..< numFrames {
+             let frame = segments[b][f]
+
+             // Find index of max value in this frame
+             guard let bestIdx = frame.indices.max(by: { frame[$0] < frame[$1] }) else {
+                 continue
+             }
+
+             // Mark the corresponding speakers as active
+             for speaker in powerset[bestIdx] {
+                 binarized[b][f][speaker] = 1.0
+             }
+         }
+     }
+
+     return binarized
+ }
+
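+ // Runs the segmentation model on one padded 10 s chunk and returns binarized
+ // per-speaker activity of shape (1, frames, 3).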
+ func getSegments(audioChunk: [Float], sampleRate _: Int = 16000, chunkSize: Int = 160_000, model: MLModel) throws -> [[[Float]]] {
+     // Ensure correct shape: (1, 1, chunkSize)
+     let audioArray = try MLMultiArray(shape: [1, 1, NSNumber(value: chunkSize)], dataType: .float32)
+     for i in 0 ..< audioChunk.count {
+         audioArray[i] = NSNumber(value: audioChunk[i])
+     }
+
+     // Prepare input
+     let input = try MLDictionaryFeatureProvider(dictionary: ["audio": audioArray])
+
+     // Run prediction
+     let output = try model.prediction(from: input)
+
+     // Extract segments output: shape assumed (1, frames, 7)
+     guard let segmentOutput = output.featureValue(for: "segments")?.multiArrayValue else {
+         throw NSError(domain: "ModelOutput", code: -1, userInfo: [NSLocalizedDescriptionKey: "Missing segments output"])
+     }
+
+     let frames = segmentOutput.shape[1].intValue
+     let combinations = segmentOutput.shape[2].intValue
+
+     // Convert MLMultiArray to [[[Float]]]
+     var segments = Array(repeating: Array(repeating: Array(repeating: 0.0 as Float, count: combinations), count: frames), count: 1)
+
+     for f in 0 ..< frames {
+         for c in 0 ..< combinations {
+             let index = f * combinations + c
+             segments[0][f][c] = segmentOutput[index].floatValue
+         }
+     }
+
+     // Apply powerset conversion
+     let binarizedSegments = powersetConversion(segments)
+
+     // Expect shape (1, frames, 3)
+     guard binarizedSegments.count == 1 else {
+         fatalError("Expected batch size 1")
+     }
+
+     let b_frames = binarizedSegments[0]
+     let numSpeakers = b_frames[0].count
+
+     // Total activity per speaker, summed over frames (debug output)
+     var speakerSums = Array(repeating: 0.0 as Float, count: numSpeakers)
+     for frame in b_frames {
+         for (i, value) in frame.enumerated() {
+             speakerSums[i] += value
+         }
+     }
+
+     print("Sum across axis 1 (frames): \(speakerSums)")
+
+     return binarizedSegments
+ }
+
+ func loadModel(from path: String) throws -> MLModel {
+     let url = URL(fileURLWithPath: path)
+     return try MLModel(contentsOf: url)
+ }
+
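+ // Entry point: load both compiled Core ML models, read the recording, and run
+ // the diarization pipeline.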
+ do {
+     let modelPath = "./pyannote_segmentation.mlmodelc"
+     let embeddingPath = "./wespeaker.mlmodelc"
+     let model = try loadModel(from: modelPath)
+     let embeddingModel = try loadModel(from: embeddingPath)
+     print("Model loaded successfully.")
+
+     // let audioPath = "./first_10_seconds.wav"
+     let audioPath = "./TS3003b_mix_headset.wav"
+
+     let audioSamples = try loadAudioSamples(from: URL(fileURLWithPath: audioPath))
+     try chunkAndRunSegmentation(samples: audioSamples, model: model, embeddingModel: embeddingModel)
+ } catch {
+     print("Error: \(error)")
+ }