ykhrustalev committed
Commit 23b5df6 · unverified · 1 parent: 1594347

correct the input

Files changed (1):
  1. audio-model.js +146 -119
audio-model.js CHANGED
@@ -397,7 +397,7 @@ export class AudioModel {
    };

    // Helper to load ONNX model with external data
-   const loadOnnxWithExternalData = async (name, progress, quantSuffix = null, epOverride = null, extraOptions = {}) => {
+   const loadOnnxWithExternalData = async (name, progress, quantSuffix = null, extraOptions = {}) => {
      const suffix = quantSuffix ? `_${quantSuffix}` : '';
      const fileName = `${name}${suffix}`;
      report('loading', progress, `${fileName}.onnx`);
@@ -405,10 +405,9 @@ export class AudioModel {
      const onnxPath = `${modelPath}/onnx/${fileName}.onnx`;
      const fetchOptions = { mode: 'cors', credentials: 'omit' };

-     const ep = epOverride || executionProviders;
-     console.log(`Loading ${fileName} (EP: ${JSON.stringify(ep)})...`);
+     console.log(`Loading ${fileName}...`);

-     const sessionOptions = { executionProviders: ep, ...extraOptions };
+     const sessionOptions = { executionProviders, ...extraOptions };

      const onnxResponse = await fetchWithCache(onnxPath, fetchOptions);
      if (!onnxResponse.ok) {
@@ -478,7 +477,7 @@ export class AudioModel {
      }
      return { preferredOutputLocation: loc };
    })() : {};
-   this.decoderSession = await loadOnnxWithExternalData('decoder', 10, quantConfig.decoder, null, decoderOpts);
+   this.decoderSession = await loadOnnxWithExternalData('decoder', 10, quantConfig.decoder, decoderOpts);

    // Load embed_tokens weight for text embedding lookup
    report('loading', 30, 'embed_tokens');
@@ -504,15 +503,13 @@ export class AudioModel {
      console.warn('Audio detokenizer not available:', e);
    }

-   // Load vocoder/depthformer (for TTS) — per-step model (8 calls per frame)
+   // Load vocoder (for TTS)
    // On WebGPU: keep KV cache on GPU to avoid GPU→CPU→GPU roundtrips between steps
    try {
      const vocoderOpts = device === 'webgpu'
        ? { preferredOutputLocation: { new_keys: 'gpu-buffer', new_values: 'gpu-buffer' } }
        : {};
-     this.vocoderSession = await loadOnnxWithExternalData(
-       'vocoder_depthformer', 95, quantConfig.vocoder, null, vocoderOpts,
-     );
+     this.vocoderSession = await loadOnnxWithExternalData('vocoder_depthformer', 95, quantConfig.vocoder, vocoderOpts);
    } catch (e) {
      console.warn('Vocoder not available:', e);
    }
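With epOverride removed, every session now shares the one executionProviders list. For orientation, a minimal standalone sketch of creating such an onnxruntime-web session (the URL, device flag, and function name are placeholders; the repo's fetchWithCache, progress reporting, and external-data handling are omitted):

// Minimal sketch (names and URL are placeholders): create an onnxruntime-web
// session with an execution-provider list, and on WebGPU ask ORT to leave the
// recurrent outputs on the GPU so they can be fed straight back as inputs.
import * as ort from 'onnxruntime-web';

async function createSession(url, device) {
  const resp = await fetch(url, { mode: 'cors', credentials: 'omit' });
  const buffer = new Uint8Array(await resp.arrayBuffer());
  const options = {
    // ORT tries these in order and falls back to the next on failure
    executionProviders: device === 'webgpu' ? ['webgpu', 'wasm'] : ['wasm'],
  };
  if (device === 'webgpu') {
    // Mirrors vocoderOpts above: keep the KV-cache outputs as GPU buffers
    options.preferredOutputLocation = { new_keys: 'gpu-buffer', new_values: 'gpu-buffer' };
  }
  return ort.InferenceSession.create(buffer, options);
}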
@@ -954,103 +951,129 @@ export class AudioModel {
    return '[Text generation requires full embedding support - model loaded successfully]';
  }

+ /**
+  * Initialize reusable vocoder tensors to reduce allocation overhead
+  */
  _initVocoderCache() {
    if (this._vocoderCache) return;

-   const vocabSize = 2049;
+   const numLayers = 6;
+   const numKvHeads = 8;
+   const headDim = 32;
+
+   // Pre-allocate data arrays
    const stepIdxData = new BigInt64Array(1);
    const prevTokenData = new BigInt64Array(1);
-   const seqlensKData = new Int32Array(1);
-   const totalSeqLenData = new Int32Array(1);

+   // Pre-allocate tensors that can be reused
    this._vocoderCache = {
+     hiddenTensor: null, // Created per-call since hiddenState changes
      stepIdxData,
      prevTokenData,
-     seqlensKData,
-     totalSeqLenData,
+     // Pre-create reusable tensors (ONNX Runtime reads from the data array)
      stepIdxTensor: new ort.Tensor('int64', stepIdxData, []),
      prevTokenTensor: new ort.Tensor('int64', prevTokenData, [1]),
-     seqlensKTensor: new ort.Tensor('int32', seqlensKData, [1]),
-     totalSeqLenTensor: new ort.Tensor('int32', totalSeqLenData, []),
-     emptyData: new Float32Array(0),
-     // Pre-allocated sampling arrays
-     scaledLogits: new Float32Array(vocabSize),
-     indices: new Uint16Array(vocabSize),
-     probs: new Float32Array(64),
+     emptyKeysData: new Float32Array(0),
+     emptyValuesData: new Float32Array(0),
+     // Reusable sampling arrays
+     scaledLogits: new Float32Array(2049), // codebook vocab size
+     indices: new Uint16Array(2049), // Use typed array for faster reset
+     probs: new Float32Array(64), // top-k size
    };
+
+   // Initialize indices
+   for (let i = 0; i < 2049; i++) {
+     this._vocoderCache.indices[i] = i;
+   }
  }

  /**
-  * Sample audio codes using per-step depthformer (8 session.run calls).
-  * Uses GroupQueryAttention with BNSH KV cache format.
+  * Sample audio codes using vocoder depthformer
+  * Optimized to reduce tensor creation overhead
   * @param {Float32Array} hiddenState - [hidden_size] hidden state
   * @param {number} temperature - Sampling temperature
-  * @param {number} topK - Top-k sampling (0 = greedy)
+  * @param {number} topK - Top-k sampling
   * @returns {number[]} - 8 codebook values
   */
  async sampleAudioCodes(hiddenState, temperature = 0.8, topK = 64) {
+   if (!this.vocoderSession) {
+     throw new Error('Vocoder not loaded');
+   }
+
+   // Initialize cache on first call
    this._initVocoderCache();
    const cache = this._vocoderCache;

    const numCodebooks = 8;
    const numLayers = 6;
-   const numKVHeads = 8;
+   const numKvHeads = 8;
    const headDim = 32;
-   const vocabSize = 2049;
-
-   const hiddenTensor = new ort.Tensor('float32', hiddenState, [1, this.hiddenSize]);
-   // BNSH format: [layers, batch, heads, seq_len, head_dim]
-   let pastKeys = new ort.Tensor('float32', cache.emptyData, [numLayers, 1, numKVHeads, 0, headDim]);
-   let pastValues = new ort.Tensor('float32', cache.emptyData, [numLayers, 1, numKVHeads, 0, headDim]);
-   let depthSlices = new ort.Tensor('float32', new Float32Array(numCodebooks * 1024), [1, numCodebooks, 1024]);

    const codes = [];
    let prevToken = 0;

-   for (let step = 0; step < numCodebooks; step++) {
-     cache.stepIdxData[0] = BigInt(step);
+   // Create hidden state tensor (must be new since data changes)
+   const hiddenTensor = new ort.Tensor('float32', hiddenState, [1, this.hiddenSize]);
+
+   // Initialize empty KV cache
+   let pastKeys = new ort.Tensor(
+     'float32',
+     cache.emptyKeysData,
+     [numLayers, 1, 0, numKvHeads, headDim]
+   );
+   let pastValues = new ort.Tensor(
+     'float32',
+     cache.emptyValuesData,
+     [numLayers, 1, 0, numKvHeads, headDim]
+   );
+
+   // Reuse step_idx and prev_token tensors by updating their data
+   cache.stepIdxData[0] = 0n;
+   cache.prevTokenData[0] = 0n;
+
+   for (let i = 0; i < numCodebooks; i++) {
+     // Update mutable tensor data (tensor objects reuse the underlying data arrays)
+     cache.stepIdxData[0] = BigInt(i);
      cache.prevTokenData[0] = BigInt(prevToken);
-     cache.seqlensKData[0] = step;
-     cache.totalSeqLenData[0] = step + 1;

-     const outputs = await this.vocoderSession.run({
+     const feeds = {
        hidden_states: hiddenTensor,
-       depth_slices_in: depthSlices,
        step_idx: cache.stepIdxTensor,
        prev_token: cache.prevTokenTensor,
        past_keys: pastKeys,
        past_values: pastValues,
-       seqlens_k: cache.seqlensKTensor,
-       total_seq_len: cache.totalSeqLenTensor,
-     });
-
-     if (step === 0) {
-       depthSlices = outputs.depth_slices;
-     }
-
-     pastKeys = outputs.new_keys;
-     pastValues = outputs.new_values;
+     };

+     const outputs = await this.vocoderSession.run(feeds);
      const logits = outputs.logits.data;
+     const vocabSize = logits.length;

+     // Sample with temperature and top-k (reusing cached arrays)
      let token;
-     if (temperature <= 0 || topK <= 1) {
+     if (temperature <= 0) {
+       // Greedy
        token = 0;
        let maxVal = logits[0];
        for (let j = 1; j < vocabSize; j++) {
-         if (logits[j] > maxVal) { maxVal = logits[j]; token = j; }
+         if (logits[j] > maxVal) {
+           maxVal = logits[j];
+           token = j;
+         }
        }
      } else {
+       // Top-k sampling with reused arrays
        const scaledLogits = cache.scaledLogits;
        const indices = cache.indices;
        const probs = cache.probs;

+       // Scale logits by temperature and find top-k in single pass
+       // Use partial selection sort (O(k*n) which is fast for small k)
        for (let j = 0; j < vocabSize; j++) {
          scaledLogits[j] = logits[j] / temperature;
          indices[j] = j;
        }

-       // Partial selection sort for top-k: O(k*n) vs O(n log n) full sort
+       // Partial sort to get top-k
        for (let j = 0; j < topK; j++) {
          let maxIdx = j;
          for (let k = j + 1; k < vocabSize; k++) {
@@ -1058,6 +1081,7 @@ export class AudioModel {
            maxIdx = k;
          }
        }
+       // Swap
        const tmp = indices[j];
        indices[j] = indices[maxIdx];
        indices[maxIdx] = tmp;
@@ -1074,18 +1098,25 @@ export class AudioModel {
          probs[j] /= sumExp;
        }

-       // Sample from cumulative distribution
+       // Sample
        const r = Math.random();
        let cumsum = 0;
-       token = indices[topK - 1];
+       token = indices[topK - 1]; // Default to last
        for (let j = 0; j < topK; j++) {
          cumsum += probs[j];
-         if (r < cumsum) { token = indices[j]; break; }
+         if (r < cumsum) {
+           token = indices[j];
+           break;
+         }
        }
      }

      codes.push(token);
      prevToken = token;
+
+     // Update KV cache
+     pastKeys = outputs.new_keys;
+     pastValues = outputs.new_values;
    }

    return codes;
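The sampling branch above is plain JavaScript and easy to test in isolation. A self-contained sketch of the same temperature plus top-k scheme (partial selection sort, softmax over the k survivors, then a draw from the cumulative distribution); the names here are illustrative, not the file's:

// Standalone sketch of temperature + top-k sampling as used above.
// logits: Float32Array of scores; returns a token index. Assumes k << logits.length.
function sampleTopK(logits, temperature, k) {
  const n = logits.length;
  if (temperature <= 0) {
    // Greedy argmax
    let best = 0;
    for (let j = 1; j < n; j++) if (logits[j] > logits[best]) best = j;
    return best;
  }
  const scaled = Float32Array.from(logits, x => x / temperature);
  const idx = Uint32Array.from({ length: n }, (_, j) => j);
  // Partial selection sort: move the k largest logits to the front, O(k*n)
  for (let j = 0; j < k; j++) {
    let m = j;
    for (let t = j + 1; t < n; t++) if (scaled[idx[t]] > scaled[idx[m]]) m = t;
    [idx[j], idx[m]] = [idx[m], idx[j]];
  }
  // Softmax over the k survivors (subtract the max for numerical stability)
  const maxLogit = scaled[idx[0]];
  const probs = new Float64Array(k);
  let sum = 0;
  for (let j = 0; j < k; j++) { probs[j] = Math.exp(scaled[idx[j]] - maxLogit); sum += probs[j]; }
  // Draw from the cumulative distribution
  let r = Math.random() * sum, cum = 0;
  for (let j = 0; j < k; j++) { cum += probs[j]; if (r < cum) return idx[j]; }
  return idx[k - 1]; // fall back to the last survivor on rounding
}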
@@ -1290,15 +1321,12 @@ export class AudioModel {
      throw new Error('Vocoder not loaded - required for interleaved mode');
    }

-   // Timing accumulators (names match liquid-audio architecture):
-   // - lfmText/lfmAudio: self.lfm() decoder calls (text vs audio steps)
-   // - depthformer: self._sample_audio_frame() — depth_linear + 8× depthformer
-   // - audioEmbed: self.audio_embedding(...).sum() — feedback embedding
+   // Timing accumulators
    let timeAudioEncode = 0;
    let timePrefill = 0;
-   let timeLfmText = 0;
-   let timeLfmAudio = 0;
-   let timeDepthformer = 0;
+   let timeTextDecode = 0;
+   let timeAudioDecode = 0;
+   let timeVocoder = 0;
    let timeAudioEmbed = 0;

    // 1. Compute mel spectrogram and encode audio
@@ -1404,37 +1432,36 @@ export class AudioModel {

    const startTime = performance.now();

-   log(`Generation loop: max ${maxNewTokens} steps, starting in TEXT mode`);
-
-   let step = 0;
-   for (; step < maxNewTokens; step++) {
+   for (let step = 0; step < maxNewTokens; step++) {
      modalityLeft--;

      if (inAudioMode) {
-       // === AUDIO STEP: extract hidden_states → depthformer → 8 codebook tokens ===
+       // Generate audio frame using depthformer
        const hiddenData = hiddenStates.data;
        const seqLen = hiddenStates.dims[1];
        const lastHidden = hiddenData.slice((seqLen - 1) * hiddenSize, seqLen * hiddenSize);

        tStep = performance.now();
        const frameCodes = await this.sampleAudioCodes(lastHidden, audioTemperature, audioTopK);
-       timeDepthformer += performance.now() - tStep;
+       timeVocoder += performance.now() - tStep;

        // Switch back to text after N audio frames (if text not done)
        if (modalityLeft <= 0 && !textDone) {
-         log(`→ AUDIO→TEXT (after ${INTERLEAVED_N_AUDIO} audio frames, ${audioCodes.length} total)`);
          inAudioMode = false;
          modalityLeft = INTERLEAVED_N_TEXT;
        }

        // Check for end of audio - first codebook == 2048 (matching liquid-audio)
        if (frameCodes[0] === END_OF_AUDIO_TOKEN) {
-         log(`→ END_OF_AUDIO at step ${step} (${audioCodes.length} frames collected)`);
+         log(`End of audio at step ${step}`);
+         // Set all codes to 2048 (matching liquid-audio)
          for (let i = 0; i < NUM_CODEBOOKS; i++) {
            frameCodes[i] = END_OF_AUDIO_TOKEN;
          }
          inAudioMode = false;
+         // Don't save this frame, but still feed it back
        } else {
+         // Save valid frame (clamped to 0-2047)
          const clampedFrame = frameCodes.map(c => Math.min(c, 2047));
          audioCodes.push(clampedFrame);

@@ -1443,15 +1470,16 @@ export class AudioModel {
          }

          if (audioCodes.length % 50 === 0) {
-           log(`  Audio frames: ${audioCodes.length}`);
+           log(`Generated ${audioCodes.length} audio frames`);
          }
        }

-       // === FEEDBACK: embed 8 codes (summed) → feed back to LFM decoder ===
+       // Get embeddings for next step (always feed back, even for 2048 frames)
        tStep = performance.now();
        const feedCodes = frameCodes.map(c => c === END_OF_AUDIO_TOKEN ? END_OF_AUDIO_TOKEN : Math.min(c, 2047));
        const audioTokens = feedCodes.map((code, idx) => idx * CODEBOOK_VOCAB + code);

+       // Get summed embeddings for all 8 codebooks
        const summedEmbeds = await this.getAudioEmbedding(audioTokens);
        timeAudioEmbed += performance.now() - tStep;

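The feedback mapping `idx * CODEBOOK_VOCAB + code` gives each of the 8 codebooks its own id range in one flat embedding table, and the 8 embedding rows are summed into a single decoder input. A sketch of that lookup, assuming CODEBOOK_VOCAB = 2049 and a flat Float32Array table (both are assumptions; the file's getAudioEmbedding may read from an ONNX initializer instead):

// Hypothetical sketch: sum the 8 codebook embeddings into one feedback vector.
// Assumes a flat table of shape [8 * codebookVocab, hiddenSize].
function sumAudioEmbeddings(frameCodes, embedTable, hiddenSize, codebookVocab = 2049) {
  const out = new Float32Array(hiddenSize);
  frameCodes.forEach((code, idx) => {
    // Flat token id -> row offset into the table
    const row = (idx * codebookVocab + code) * hiddenSize;
    for (let d = 0; d < hiddenSize; d++) out[d] += embedTable[row + d];
  });
  return out;
}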
@@ -1460,13 +1488,14 @@ export class AudioModel {
        const nextMask = new ort.Tensor('int64', new BigInt64Array(currentLen).fill(1n), [1, currentLen]);
        tStep = performance.now();
        ({ logits, hiddenStates, outputs } = await this.runDecoder(nextEmbeds, nextMask, this.cache));
-       timeLfmAudio += performance.now() - tStep;
+       timeAudioDecode += performance.now() - tStep;
        this.updateCache(this.cache, outputs);

      } else {
-       // === TEXT STEP: logits → sample text token ===
+       // Generate text token
        const logitsData = logits.data;
        const seqLen = logits.dims[1];
+       // Get logits for last position - shape is [1, seq_len, vocab_size]
        const lastLogits = new Float32Array(this.vocabSize);
        const offset = (seqLen - 1) * this.vocabSize;
        for (let i = 0; i < this.vocabSize; i++) {
@@ -1476,19 +1505,18 @@ export class AudioModel {

        // Check for end of turn
        if (token === this.tokenizer.eos_token_id || token === SPECIAL_TOKENS.IM_END) {
-         log(`→ END_OF_TURN at step ${step} (${textTokens.length} text tokens, ${audioCodes.length} audio frames)`);
+         log(`End of turn at step ${step}`);
          break;
        }

        // Check for <|text_end|> token (130)
        if (token === SPECIAL_TOKENS.TEXT_END) {
-         log(`→ TEXT_END at step ${step}: audio-only phase begins`);
+         log(`Text end at step ${step}`);
          textDone = true;
        }

        // Switch to audio after N text tokens OR text_end
        if (modalityLeft <= 0 || textDone) {
-         log(`→ TEXT→AUDIO${textDone ? ' (text_done)' : ''} at step ${step}`);
          inAudioMode = true;
          modalityLeft = INTERLEAVED_N_AUDIO;
        }
@@ -1500,19 +1528,18 @@ export class AudioModel {
          onToken(decodedText, token);
        }

-       // === FEEDBACK: embed text token → feed back to LFM decoder ===
+       // Get embedding for next step
        const nextEmbeds = this.getTextEmbeddings([token]);
        currentLen++;
        const nextMask = new ort.Tensor('int64', new BigInt64Array(currentLen).fill(1n), [1, currentLen]);
        tStep = performance.now();
        ({ logits, hiddenStates, outputs } = await this.runDecoder(nextEmbeds, nextMask, this.cache));
-       timeLfmText += performance.now() - tStep;
+       timeTextDecode += performance.now() - tStep;
        this.updateCache(this.cache, outputs);
      }
    }

    // 5. Feed <|im_end|> token to close assistant turn in cache
-
    const imEndEmbeds = this.getTextEmbeddings([SPECIAL_TOKENS.IM_END]);
    currentLen++;
    const finalMask = new ort.Tensor('int64', new BigInt64Array(currentLen).fill(1n), [1, currentLen]);
@@ -1523,11 +1550,14 @@ export class AudioModel {
    // Decode with skip_special_tokens to clean up special tokens like <|text_end|>
    const text = this.tokenizer.decode(textTokens, { skip_special_tokens: true });

-   const totalGenTime = performance.now() - startTime;
-
-   log(`Done: ${step} steps in ${totalGenTime.toFixed(0)}ms | ${textTokens.length} text tokens, ${audioCodes.length} audio frames (~${(audioCodes.length / 75).toFixed(1)}s audio)`);
-   log(`Timing: mel=${timeMel.toFixed(0)}ms, audioEnc=${timeAudioEncode.toFixed(0)}ms, prefill=${timePrefill.toFixed(0)}ms, lfmText=${timeLfmText.toFixed(0)}ms, lfmAudio=${timeLfmAudio.toFixed(0)}ms, depthformer=${timeDepthformer.toFixed(0)}ms, audioEmbed=${timeAudioEmbed.toFixed(0)}ms`);
-   log(`Text: "${text}" | cache_seq_len=${this.cacheSeqLen}`);
+   // Print timing summary
+   log(`=== Summary ===`);
+   log(`  Mel: ${timeMel.toFixed(0)}ms, AudioEnc: ${timeAudioEncode.toFixed(0)}ms, Prefill: ${timePrefill.toFixed(0)}ms`);
+   log(`  TextDec: ${timeTextDecode.toFixed(0)}ms (${textTokens.length} tok), AudioDec: ${timeAudioDecode.toFixed(0)}ms`);
+   log(`  Vocoder: ${timeVocoder.toFixed(0)}ms, AudioEmbed: ${timeAudioEmbed.toFixed(0)}ms`);
+   log(`Output: ${textTokens.length} text tokens, ${audioCodes.length} audio frames`);
+   log(`Text: "${text}"`);
+   log(`Cache seq_len: ${this.cacheSeqLen}`);

    return { text, audioCodes };
  }
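Stripped of decoding details, the loop above is a small two-state machine that alternates bursts of text tokens and audio frames. A schematic sketch under that reading (generateTextToken and generateAudioFrame are stand-ins for the decoder and sampler calls shown above; end-of-turn handling and feedback embedding are elided):

// Schematic of the interleaving policy driving the generation loop.
// INTERLEAVED_N_TEXT / INTERLEAVED_N_AUDIO / END_OF_AUDIO_TOKEN /
// SPECIAL_TOKENS.TEXT_END are the file's constants.
let inAudioMode = false;
let textDone = false;
let modalityLeft = INTERLEAVED_N_TEXT; // start with a text burst

for (let step = 0; step < maxNewTokens; step++) {
  modalityLeft--;
  if (inAudioMode) {
    const frame = generateAudioFrame();        // depthformer -> 8 codebook codes
    if (modalityLeft <= 0 && !textDone) {      // audio burst done, text remains
      inAudioMode = false;
      modalityLeft = INTERLEAVED_N_TEXT;
    }
    if (frame[0] === END_OF_AUDIO_TOKEN) inAudioMode = false;
  } else {
    const token = generateTextToken();         // decoder logits -> sampled token
    if (token === SPECIAL_TOKENS.TEXT_END) textDone = true; // rest is audio-only
    if (modalityLeft <= 0 || textDone) {       // text burst done
      inAudioMode = true;
      modalityLeft = INTERLEAVED_N_AUDIO;
    }
  }
}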
@@ -1568,14 +1598,11 @@ export class AudioModel {
      throw new Error('Vocoder not loaded - required for interleaved mode');
    }

-   // Timing accumulators (names match liquid-audio architecture):
-   // - lfmText/lfmAudio: self.lfm() decoder calls (text vs audio steps)
-   // - depthformer: self._sample_audio_frame() — depth_linear + 8× depthformer
-   // - audioEmbed: self.audio_embedding(...).sum() — feedback embedding
+   // Timing accumulators
    let timePrefill = 0;
-   let timeLfmText = 0;
-   let timeLfmAudio = 0;
-   let timeDepthformer = 0;
+   let timeTextDecode = 0;
+   let timeAudioDecode = 0;
+   let timeVocoder = 0;
    let timeAudioEmbed = 0;
    let tStep;

@@ -1640,18 +1667,17 @@ export class AudioModel {

        tStep = performance.now();
        const frameCodes = await this.sampleAudioCodes(lastHidden, audioTemperature, audioTopK);
-       timeDepthformer += performance.now() - tStep;
+       timeVocoder += performance.now() - tStep;

        // Switch back to text after N audio frames (if text not done)
        if (modalityLeft <= 0 && !textDone) {
-         log(`→ AUDIO→TEXT (after ${INTERLEAVED_N_AUDIO} audio frames, ${audioCodes.length} total)`);
          inAudioMode = false;
          modalityLeft = INTERLEAVED_N_TEXT;
        }

        // Check for end of audio
        if (frameCodes[0] === END_OF_AUDIO_TOKEN) {
-         log(`→ END_OF_AUDIO at step ${step} (${audioCodes.length} frames collected)`);
+         log(`End of audio at step ${step}`);
          for (let i = 0; i < NUM_CODEBOOKS; i++) {
            frameCodes[i] = END_OF_AUDIO_TOKEN;
          }
@@ -1665,7 +1691,7 @@ export class AudioModel {
          }

          if (audioCodes.length % 50 === 0) {
-           log(`  Audio frames: ${audioCodes.length}`);
+           log(`Generated ${audioCodes.length} audio frames`);
          }
        }

@@ -1681,7 +1707,7 @@ export class AudioModel {
        const nextMask = new ort.Tensor('int64', new BigInt64Array(currentLen).fill(1n), [1, currentLen]);
        tStep = performance.now();
        ({ logits, hiddenStates, outputs } = await this.runDecoder(nextEmbeds, nextMask, this.cache));
-       timeLfmAudio += performance.now() - tStep;
+       timeAudioDecode += performance.now() - tStep;
        this.updateCache(this.cache, outputs);

      } else {
@@ -1697,19 +1723,18 @@ export class AudioModel {

        // Check for end of turn
        if (token === this.tokenizer.eos_token_id || token === SPECIAL_TOKENS.IM_END) {
-         log(`→ END_OF_TURN at step ${step} (${textTokens.length} text tokens, ${audioCodes.length} audio frames)`);
+         log(`End of turn at step ${step}`);
          break;
        }

        // Check for <|text_end|> token
        if (token === SPECIAL_TOKENS.TEXT_END) {
-         log(`→ TEXT_END at step ${step}: audio-only phase begins`);
+         log(`Text end at step ${step}`);
          textDone = true;
        }

        // Switch to audio after N text tokens OR text_end
        if (modalityLeft <= 0 || textDone) {
-         log(`→ TEXT→AUDIO${textDone ? ' (text_done)' : ''} at step ${step}`);
          inAudioMode = true;
          modalityLeft = INTERLEAVED_N_AUDIO;
        }
@@ -1727,7 +1752,7 @@ export class AudioModel {
        const nextMask = new ort.Tensor('int64', new BigInt64Array(currentLen).fill(1n), [1, currentLen]);
        tStep = performance.now();
        ({ logits, hiddenStates, outputs } = await this.runDecoder(nextEmbeds, nextMask, this.cache));
-       timeLfmText += performance.now() - tStep;
+       timeTextDecode += performance.now() - tStep;
        this.updateCache(this.cache, outputs);
      }
    }
@@ -1742,9 +1767,13 @@ export class AudioModel {

    const text = this.tokenizer.decode(textTokens, { skip_special_tokens: true });

-   log(`Done: ${textTokens.length} text tokens, ${audioCodes.length} audio frames (~${(audioCodes.length / 75).toFixed(1)}s audio)`);
-   log(`Timing: prefill=${timePrefill.toFixed(0)}ms, lfmText=${timeLfmText.toFixed(0)}ms, lfmAudio=${timeLfmAudio.toFixed(0)}ms, depthformer=${timeDepthformer.toFixed(0)}ms, audioEmbed=${timeAudioEmbed.toFixed(0)}ms`);
-   log(`Text: "${text}" | cache_seq_len=${this.cacheSeqLen}`);
+   log(`=== Summary ===`);
+   log(`  Prefill: ${timePrefill.toFixed(0)}ms`);
+   log(`  TextDec: ${timeTextDecode.toFixed(0)}ms (${textTokens.length} tok), AudioDec: ${timeAudioDecode.toFixed(0)}ms`);
+   log(`  Vocoder: ${timeVocoder.toFixed(0)}ms, AudioEmbed: ${timeAudioEmbed.toFixed(0)}ms`);
+   log(`Output: ${textTokens.length} text tokens, ${audioCodes.length} audio frames`);
+   log(`Text: "${text}"`);
+   log(`Cache seq_len: ${this.cacheSeqLen}`);

    return { text, audioCodes };
  }
@@ -1874,8 +1903,7 @@ export class AudioModel {
    }

    const decodeStart = performance.now();
-
-   log(`Audio decode: ${audioCodes.length} frames → waveform`);
+   log(`Decoding ${audioCodes.length} audio frames...`);

    // ISTFT parameters (fixed for this model)
    const nFft = 1280;
@@ -1883,7 +1911,7 @@ export class AudioModel {
    const winLength = 1280;
    const nFftBins = nFft / 2 + 1;

-   // Transpose codes [T, 8] → [1, 8, T] for ONNX input
+   // Stack codes: [T, 8] -> [8, T] and add batch -> [1, 8, T]
    const T = audioCodes.length;
    const codesTransposed = new BigInt64Array(8 * T);
    for (let t = 0; t < T; t++) {
@@ -1892,18 +1920,18 @@ export class AudioModel {
      }
    }

-   // Run detokenizer ONNX: [1, 8, T] → [1, 6T, 1282]
+   // Run detokenizer: [1, 8, T] -> [1, T, 1282]
    const codesTensor = new ort.Tensor('int64', codesTransposed, [1, 8, T]);
    const detokStart = performance.now();
    const detokOutputs = await this.audioDetokenizerSession.run({ audio_codes: codesTensor });
    const stftFeatures = detokOutputs.stft_features;
-   const actualT = stftFeatures.dims[1];
-   const detokEnd = performance.now();
-   log(`  Detokenizer: [1,8,${T}] → [1,${actualT},1282] in ${(detokEnd - detokStart).toFixed(0)}ms`);
+   log(`Detokenizer: ${(performance.now() - detokStart).toFixed(0)}ms, STFT frames: ${stftFeatures.dims[1]}`);

-   // Split into magnitude + angle complex spectrogram
+   // Get raw data - shape is [1, T, 1282], we need to skip batch dimension
    const stftData = stftFeatures.data;
+   const actualT = stftFeatures.dims[1];

+   // Convert to complex STFT: [log_magnitude | angle] -> complex
    const complexStft = new Array(nFftBins);
    for (let f = 0; f < nFftBins; f++) {
      complexStft[f] = new Array(actualT);
@@ -1911,23 +1939,23 @@ export class AudioModel {
        const logMag = stftData[t * 1282 + f];
        const angle = stftData[t * 1282 + nFftBins + f];
        const mag = Math.exp(logMag);
+       // Store as [real, imag]
        complexStft[f][t] = [mag * Math.cos(angle), mag * Math.sin(angle)];
      }
    }

-   // ISTFT (inverse Short-Time Fourier Transform) → waveform
-   const pad = (winLength - hopLength) / 2;
+   // ISTFT with 'same' padding
    const istftStart = performance.now();
    const waveform = this.istftSamePadding(complexStft, nFft, hopLength, winLength, actualT);
-   const istftEnd = performance.now();
-   log(`  ISTFT: ${actualT} frames → ${waveform.length} samples in ${(istftEnd - istftStart).toFixed(0)}ms`);
+   log(`ISTFT: ${(performance.now() - istftStart).toFixed(0)}ms`);

-   // Find max/min
+   // Find max/min without spread operator (avoid stack overflow on large arrays)
    let waveMax = -Infinity, waveMin = Infinity;
    for (let i = 0; i < waveform.length; i++) {
      if (waveform[i] > waveMax) waveMax = waveform[i];
      if (waveform[i] < waveMin) waveMin = waveform[i];
    }
+   log('ISTFT output - length:', waveform.length, 'max:', waveMax.toFixed(4), 'min:', waveMin.toFixed(4));

    // Check for invalid values
    if (isNaN(waveMax) || isNaN(waveMin) || !isFinite(waveMax) || !isFinite(waveMin)) {
@@ -1935,7 +1963,7 @@ export class AudioModel {
      return new Float32Array(0);
    }

-   // Normalize waveform to [-0.9, 0.9]
+   // Normalize to [-1, 1]
    let maxVal = Math.max(Math.abs(waveMax), Math.abs(waveMin));
    if (maxVal > 0) {
      for (let i = 0; i < waveform.length; i++) {
@@ -1945,8 +1973,7 @@ export class AudioModel {
      console.warn('ISTFT produced all-zero waveform');
    }

-   const totalDecodeTime = performance.now() - decodeStart;
-   log(`Audio decode complete: ${totalDecodeTime.toFixed(0)}ms total (detok=${(detokEnd - detokStart).toFixed(0)}ms, istft=${(istftEnd - istftStart).toFixed(0)}ms) → ${waveform.length} samples, ${(waveform.length / 24000).toFixed(2)}s @ 24kHz`);
+   log(`Decoded audio: ${waveform.length} samples (${(waveform.length / 24000).toFixed(2)}s)`);

    return waveform;
  }
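istftSamePadding itself is not part of this diff. For orientation, a generic overlap-add ISTFT sketch of the kind that call implies (hannWindow is defined inline; idft is an assumed helper that inverse-FFTs one Hermitian spectrum column; the real method may differ in details):

// Hann synthesis window
function hannWindow(n) {
  const w = new Float32Array(n);
  for (let i = 0; i < n; i++) w[i] = 0.5 * (1 - Math.cos((2 * Math.PI * i) / n));
  return w;
}

// Overlap-add ISTFT sketch: per frame, inverse-FFT, window, accumulate;
// normalize by the summed squared window; trim (win - hop) / 2 per edge ('same' padding).
function istftOverlapAdd(complexStft, nFft, hop, winLength, frames, idft) {
  const win = hannWindow(winLength);
  const outLen = nFft + hop * (frames - 1);
  const out = new Float32Array(outLen);
  const norm = new Float32Array(outLen);
  for (let t = 0; t < frames; t++) {
    const frame = idft(complexStft, t, nFft); // assumed: real ifft of column t, length nFft
    const off = t * hop;
    for (let i = 0; i < winLength; i++) {
      out[off + i] += frame[i] * win[i];
      norm[off + i] += win[i] * win[i];
    }
  }
  for (let i = 0; i < outLen; i++) if (norm[i] > 1e-8) out[i] /= norm[i];
  const pad = (winLength - hop) / 2;
  return out.subarray(pad, outLen - pad);
}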
1979