ykhrustalev committed
Commit da13637 · unverified · 1 parent: e07bd97

Handle the separate codebook step

Files changed (2):
  1. audio-model.js  +136 -146
  2. main.js  +3 -3
audio-model.js CHANGED
@@ -397,7 +397,7 @@ export class AudioModel {
     };
 
     // Helper to load ONNX model with external data
-    const loadOnnxWithExternalData = async (name, progress, quantSuffix = null) => {
+    const loadOnnxWithExternalData = async (name, progress, quantSuffix = null, epOverride = null, extraOptions = {}) => {
       const suffix = quantSuffix ? `_${quantSuffix}` : '';
       const fileName = `${name}${suffix}`;
       report('loading', progress, `${fileName}.onnx`);
@@ -405,9 +405,10 @@ export class AudioModel {
       const onnxPath = `${modelPath}/onnx/${fileName}.onnx`;
       const fetchOptions = { mode: 'cors', credentials: 'omit' };
 
-      console.log(`Loading ${fileName}...`);
+      const ep = epOverride || executionProviders;
+      console.log(`Loading ${fileName} (EP: ${JSON.stringify(ep)})...`);
 
-      const sessionOptions = { executionProviders };
+      const sessionOptions = { executionProviders: ep, ...extraOptions };
 
       const onnxResponse = await fetchWithCache(onnxPath, fetchOptions);
       if (!onnxResponse.ok) {
@@ -464,7 +465,20 @@ export class AudioModel {
     };
 
     // Load decoder
-    this.decoderSession = await loadOnnxWithExternalData('decoder', 10, quantConfig.decoder);
+    // On WebGPU: keep KV cache outputs on GPU to avoid GPU→CPU→GPU roundtrips between steps
+    const decoderOpts = device === 'webgpu' ? (() => {
+      const loc = {};
+      for (let i = 0; i < this.layerTypes.length; i++) {
+        if (this.layerTypes[i] === 'conv') {
+          loc[`present_conv.${i}`] = 'gpu-buffer';
+        } else {
+          loc[`present.${i}.key`] = 'gpu-buffer';
+          loc[`present.${i}.value`] = 'gpu-buffer';
+        }
+      }
+      return { preferredOutputLocation: loc };
+    })() : {};
+    this.decoderSession = await loadOnnxWithExternalData('decoder', 10, quantConfig.decoder, null, decoderOpts);
 
     // Load embed_tokens weight for text embedding lookup
     report('loading', 30, 'embed_tokens');
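
Note: `preferredOutputLocation` is the onnxruntime-web session option that pins named outputs in GPU memory (`'gpu-buffer'`) rather than copying them back to JS after every `run()`. A minimal sketch of the feedback pattern this enables, assuming a session created with the options above; `session`, `initialFeeds`, and `numSteps` are hypothetical names, not part of this commit:

    // Sketch: GPU-resident KV cache feedback loop (assumed ort-web usage).
    async function runWithGpuResidentKv(session, initialFeeds, numSteps) {
      let feeds = initialFeeds;
      for (let i = 0; i < numSteps; i++) {
        const out = await session.run(feeds);
        // Outputs listed in preferredOutputLocation stay on the GPU and can be
        // fed straight back as next-step inputs: no GPU→CPU→GPU copy per step.
        feeds = { ...feeds, past_keys: out.new_keys, past_values: out.new_values };
        // Only tensors actually read from JS get downloaded; getData() copies a
        // 'gpu-buffer' tensor back to the CPU on demand.
        const logits = await out.logits.getData();
        void logits; // sample the next token from logits here
      }
    }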
@@ -490,9 +504,15 @@ export class AudioModel {
       console.warn('Audio detokenizer not available:', e);
     }
 
-    // Load vocoder (for TTS)
+    // Load vocoder/depthformer (for TTS) — per-step model (8 calls per frame)
+    // On WebGPU: keep KV cache on GPU to avoid GPU→CPU→GPU roundtrips between steps
     try {
-      this.vocoderSession = await loadOnnxWithExternalData('vocoder_depthformer', 95, quantConfig.vocoder);
+      const vocoderOpts = device === 'webgpu'
+        ? { preferredOutputLocation: { new_keys: 'gpu-buffer', new_values: 'gpu-buffer', depth_slices: 'gpu-buffer' } }
+        : {};
+      this.vocoderSession = await loadOnnxWithExternalData(
+        'vocoder_depthformer', 95, quantConfig.vocoder, null, vocoderOpts,
+      );
     } catch (e) {
      console.warn('Vocoder not available:', e);
    }
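
Note: together with the reworked `sampleAudioCodes` below, the per-step contract of the separate-codebook vocoder graph can be summarized as follows. This is a sketch assembled from the tensor names and shapes visible in this diff, not an authoritative dump of the exported graph signature:

    // One session.run per codebook step s (s = 0..7), per audio frame:
    //   inputs:  hidden_states    float32 [1, hidden_size]
    //            depth_slices_in  float32 [1, 8, 1024]
    //            step_idx         int64   []    = s
    //            prev_token       int64   [1]   = token sampled at step s-1 (0 at s=0)
    //            past_keys/values float32 [6, 1, 8, s, 32]  (BNSH; empty at s=0)
    //            seqlens_k        int32   [1]   = s
    //            total_seq_len    int32   []    = s + 1
    //   outputs: logits           float32, length 2049 (codebook vocab)
    //            depth_slices     float32 [1, 8, 1024]  (captured at s=0, reused after)
    //            new_keys/values  float32 [6, 1, 8, s+1, 32]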
@@ -934,129 +954,103 @@ export class AudioModel {
     return '[Text generation requires full embedding support - model loaded successfully]';
   }
 
-  /**
-   * Initialize reusable vocoder tensors to reduce allocation overhead
-   */
   _initVocoderCache() {
     if (this._vocoderCache) return;
 
-    const numLayers = 6;
-    const numKvHeads = 8;
-    const headDim = 32;
-
-    // Pre-allocate data arrays
+    const vocabSize = 2049;
     const stepIdxData = new BigInt64Array(1);
     const prevTokenData = new BigInt64Array(1);
+    const seqlensKData = new Int32Array(1);
+    const totalSeqLenData = new Int32Array(1);
 
-    // Pre-allocate tensors that can be reused
     this._vocoderCache = {
-      hiddenTensor: null, // Created per-call since hiddenState changes
       stepIdxData,
       prevTokenData,
-      // Pre-create reusable tensors (ONNX Runtime reads from the data array)
+      seqlensKData,
+      totalSeqLenData,
       stepIdxTensor: new ort.Tensor('int64', stepIdxData, []),
       prevTokenTensor: new ort.Tensor('int64', prevTokenData, [1]),
-      emptyKeysData: new Float32Array(0),
-      emptyValuesData: new Float32Array(0),
-      // Reusable sampling arrays
-      scaledLogits: new Float32Array(2049), // codebook vocab size
-      indices: new Uint16Array(2049), // Use typed array for faster reset
-      probs: new Float32Array(64), // top-k size
+      seqlensKTensor: new ort.Tensor('int32', seqlensKData, [1]),
+      totalSeqLenTensor: new ort.Tensor('int32', totalSeqLenData, []),
+      emptyData: new Float32Array(0),
+      // Pre-allocated sampling arrays
+      scaledLogits: new Float32Array(vocabSize),
+      indices: new Uint16Array(vocabSize),
+      probs: new Float32Array(64),
     };
-
-    // Initialize indices
-    for (let i = 0; i < 2049; i++) {
-      this._vocoderCache.indices[i] = i;
-    }
   }
 
   /**
-   * Sample audio codes using vocoder depthformer
-   * Optimized to reduce tensor creation overhead
+   * Sample audio codes using per-step depthformer (8 session.run calls).
+   * Uses GroupQueryAttention with BNSH KV cache format.
    * @param {Float32Array} hiddenState - [hidden_size] hidden state
    * @param {number} temperature - Sampling temperature
-   * @param {number} topK - Top-k sampling
+   * @param {number} topK - Top-k sampling (0 = greedy)
   * @returns {number[]} - 8 codebook values
   */
  async sampleAudioCodes(hiddenState, temperature = 0.8, topK = 64) {
-    if (!this.vocoderSession) {
-      throw new Error('Vocoder not loaded');
-    }
-
-    // Initialize cache on first call
     this._initVocoderCache();
     const cache = this._vocoderCache;
 
     const numCodebooks = 8;
     const numLayers = 6;
-    const numKvHeads = 8;
+    const numKVHeads = 8;
     const headDim = 32;
+    const vocabSize = 2049;
 
-    const codes = [];
-    let prevToken = 0;
-
-    // Create hidden state tensor (must be new since data changes)
     const hiddenTensor = new ort.Tensor('float32', hiddenState, [1, this.hiddenSize]);
-
-    // Initialize empty KV cache
-    let pastKeys = new ort.Tensor(
-      'float32',
-      cache.emptyKeysData,
-      [numLayers, 1, 0, numKvHeads, headDim]
-    );
-    let pastValues = new ort.Tensor(
-      'float32',
-      cache.emptyValuesData,
-      [numLayers, 1, 0, numKvHeads, headDim]
-    );
-
-    // Reuse step_idx and prev_token tensors by updating their data
-    cache.stepIdxData[0] = 0n;
-    cache.prevTokenData[0] = 0n;
+    // BNSH format: [layers, batch, heads, seq_len, head_dim]
+    let pastKeys = new ort.Tensor('float32', cache.emptyData, [numLayers, 1, numKVHeads, 0, headDim]);
+    let pastValues = new ort.Tensor('float32', cache.emptyData, [numLayers, 1, numKVHeads, 0, headDim]);
+    let depthSlices = new ort.Tensor('float32', new Float32Array(numCodebooks * 1024), [1, numCodebooks, 1024]);
 
-    for (let i = 0; i < numCodebooks; i++) {
-      // Update mutable tensor data (tensor objects reuse the underlying data arrays)
-      cache.stepIdxData[0] = BigInt(i);
+    const codes = [];
+    let prevToken = 0;
+
+    for (let step = 0; step < numCodebooks; step++) {
+      cache.stepIdxData[0] = BigInt(step);
       cache.prevTokenData[0] = BigInt(prevToken);
+      cache.seqlensKData[0] = step;
+      cache.totalSeqLenData[0] = step + 1;
 
-      const feeds = {
+      const outputs = await this.vocoderSession.run({
        hidden_states: hiddenTensor,
+        depth_slices_in: depthSlices,
        step_idx: cache.stepIdxTensor,
        prev_token: cache.prevTokenTensor,
        past_keys: pastKeys,
        past_values: pastValues,
-      };
+        seqlens_k: cache.seqlensKTensor,
+        total_seq_len: cache.totalSeqLenTensor,
+      });
+
+      if (step === 0) {
+        depthSlices = outputs.depth_slices;
+      }
+
+      pastKeys = outputs.new_keys;
+      pastValues = outputs.new_values;
 
-      const outputs = await this.vocoderSession.run(feeds);
      const logits = outputs.logits.data;
-      const vocabSize = logits.length;
 
-      // Sample with temperature and top-k (reusing cached arrays)
      let token;
-      if (temperature <= 0) {
-        // Greedy
+      if (temperature <= 0 || topK <= 1) {
        token = 0;
        let maxVal = logits[0];
        for (let j = 1; j < vocabSize; j++) {
-          if (logits[j] > maxVal) {
-            maxVal = logits[j];
-            token = j;
-          }
+          if (logits[j] > maxVal) { maxVal = logits[j]; token = j; }
        }
      } else {
-        // Top-k sampling with reused arrays
        const scaledLogits = cache.scaledLogits;
        const indices = cache.indices;
        const probs = cache.probs;
 
-        // Scale logits by temperature and find top-k in single pass
-        // Use partial selection sort (O(k*n) which is fast for small k)
        for (let j = 0; j < vocabSize; j++) {
          scaledLogits[j] = logits[j] / temperature;
          indices[j] = j;
        }
 
-        // Partial sort to get top-k
+        // Partial selection sort for top-k: O(k*n) vs O(n log n) full sort
        for (let j = 0; j < topK; j++) {
          let maxIdx = j;
          for (let k = j + 1; k < vocabSize; k++) {
@@ -1064,7 +1058,6 @@ export class AudioModel {
              maxIdx = k;
            }
          }
-          // Swap
          const tmp = indices[j];
          indices[j] = indices[maxIdx];
          indices[maxIdx] = tmp;
@@ -1081,25 +1074,18 @@ export class AudioModel {
          probs[j] /= sumExp;
        }
 
-        // Sample
+        // Sample from cumulative distribution
        const r = Math.random();
        let cumsum = 0;
-        token = indices[topK - 1]; // Default to last
+        token = indices[topK - 1];
        for (let j = 0; j < topK; j++) {
          cumsum += probs[j];
-          if (r < cumsum) {
-            token = indices[j];
-            break;
-          }
+          if (r < cumsum) { token = indices[j]; break; }
        }
      }
 
      codes.push(token);
      prevToken = token;
-
-      // Update KV cache
-      pastKeys = outputs.new_keys;
-      pastValues = outputs.new_values;
    }
 
    return codes;
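
Note: the sampling path is unchanged in spirit: temperature-scaled logits, a partial selection sort for the top k, softmax over those k, then a roulette-wheel draw. The same logic as a standalone sketch (a hypothetical helper, not part of the diff; the in-tree version reuses cached arrays instead of allocating):

    function sampleTopK(logits, temperature, topK) {
      const n = logits.length;
      const idx = Array.from({ length: n }, (_, i) => i);
      const scaled = Array.from(logits, (v) => v / temperature);
      // Partial selection sort: move the k largest entries to the front, O(k*n).
      for (let j = 0; j < topK; j++) {
        let m = j;
        for (let k = j + 1; k < n; k++) {
          if (scaled[idx[k]] > scaled[idx[m]]) m = k;
        }
        [idx[j], idx[m]] = [idx[m], idx[j]];
      }
      // Softmax over the k survivors (subtract the max for numeric stability).
      const maxLogit = scaled[idx[0]];
      const exps = idx.slice(0, topK).map((i) => Math.exp(scaled[i] - maxLogit));
      const sum = exps.reduce((a, b) => a + b, 0);
      // Roulette-wheel draw from the cumulative distribution.
      let r = Math.random() * sum;
      let acc = 0;
      for (let j = 0; j < topK; j++) {
        acc += exps[j];
        if (r < acc) return idx[j];
      }
      return idx[topK - 1]; // fallback against floating-point undershoot
    }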
@@ -1304,12 +1290,15 @@ export class AudioModel {
      throw new Error('Vocoder not loaded - required for interleaved mode');
    }
 
-    // Timing accumulators
+    // Timing accumulators (names match liquid-audio architecture):
+    // - lfmText/lfmAudio: self.lfm() decoder calls (text vs audio steps)
+    // - depthformer: self._sample_audio_frame() — depth_linear + 8× depthformer
+    // - audioEmbed: self.audio_embedding(...).sum() — feedback embedding
    let timeAudioEncode = 0;
    let timePrefill = 0;
-    let timeTextDecode = 0;
-    let timeAudioDecode = 0;
-    let timeVocoder = 0;
+    let timeLfmText = 0;
+    let timeLfmAudio = 0;
+    let timeDepthformer = 0;
    let timeAudioEmbed = 0;
 
    // 1. Compute mel spectrogram and encode audio
@@ -1415,36 +1404,37 @@ export class AudioModel {
 
    const startTime = performance.now();
 
-    for (let step = 0; step < maxNewTokens; step++) {
+    log(`Generation loop: max ${maxNewTokens} steps, starting in TEXT mode`);
+
+    let step = 0;
+    for (; step < maxNewTokens; step++) {
      modalityLeft--;
 
      if (inAudioMode) {
-        // Generate audio frame using depthformer
+        // === AUDIO STEP: extract hidden_states → depthformer → 8 codebook tokens ===
        const hiddenData = hiddenStates.data;
        const seqLen = hiddenStates.dims[1];
        const lastHidden = hiddenData.slice((seqLen - 1) * hiddenSize, seqLen * hiddenSize);
 
        tStep = performance.now();
        const frameCodes = await this.sampleAudioCodes(lastHidden, audioTemperature, audioTopK);
-        timeVocoder += performance.now() - tStep;
+        timeDepthformer += performance.now() - tStep;
 
        // Switch back to text after N audio frames (if text not done)
        if (modalityLeft <= 0 && !textDone) {
+          log(`→ AUDIO→TEXT (after ${INTERLEAVED_N_AUDIO} audio frames, ${audioCodes.length} total)`);
          inAudioMode = false;
          modalityLeft = INTERLEAVED_N_TEXT;
        }
 
        // Check for end of audio - first codebook == 2048 (matching liquid-audio)
        if (frameCodes[0] === END_OF_AUDIO_TOKEN) {
-          log(`End of audio at step ${step}`);
-          // Set all codes to 2048 (matching liquid-audio)
+          log(`→ END_OF_AUDIO at step ${step} (${audioCodes.length} frames collected)`);
          for (let i = 0; i < NUM_CODEBOOKS; i++) {
            frameCodes[i] = END_OF_AUDIO_TOKEN;
          }
          inAudioMode = false;
-          // Don't save this frame, but still feed it back
        } else {
-          // Save valid frame (clamped to 0-2047)
          const clampedFrame = frameCodes.map(c => Math.min(c, 2047));
          audioCodes.push(clampedFrame);
 
@@ -1453,16 +1443,15 @@ export class AudioModel {
        }
 
        if (audioCodes.length % 50 === 0) {
-          log(`Generated ${audioCodes.length} audio frames`);
+          log(`  Audio frames: ${audioCodes.length}`);
        }
      }
 
-      // Get embeddings for next step (always feed back, even for 2048 frames)
+      // === FEEDBACK: embed 8 codes (summed) → feed back to LFM decoder ===
      tStep = performance.now();
      const feedCodes = frameCodes.map(c => c === END_OF_AUDIO_TOKEN ? END_OF_AUDIO_TOKEN : Math.min(c, 2047));
      const audioTokens = feedCodes.map((code, idx) => idx * CODEBOOK_VOCAB + code);
 
-      // Get summed embeddings for all 8 codebooks
      const summedEmbeds = await this.getAudioEmbedding(audioTokens);
      timeAudioEmbed += performance.now() - tStep;
 
@@ -1471,14 +1460,13 @@ export class AudioModel {
      const nextMask = new ort.Tensor('int64', new BigInt64Array(currentLen).fill(1n), [1, currentLen]);
      tStep = performance.now();
      ({ logits, hiddenStates, outputs } = await this.runDecoder(nextEmbeds, nextMask, this.cache));
-      timeAudioDecode += performance.now() - tStep;
+      timeLfmAudio += performance.now() - tStep;
      this.updateCache(this.cache, outputs);
 
    } else {
-      // Generate text token
+      // === TEXT STEP: logits → sample text token ===
      const logitsData = logits.data;
      const seqLen = logits.dims[1];
-      // Get logits for last position - shape is [1, seq_len, vocab_size]
      const lastLogits = new Float32Array(this.vocabSize);
      const offset = (seqLen - 1) * this.vocabSize;
      for (let i = 0; i < this.vocabSize; i++) {
@@ -1488,18 +1476,19 @@ export class AudioModel {
 
      // Check for end of turn
      if (token === this.tokenizer.eos_token_id || token === SPECIAL_TOKENS.IM_END) {
-        log(`End of turn at step ${step}`);
+        log(`→ END_OF_TURN at step ${step} (${textTokens.length} text tokens, ${audioCodes.length} audio frames)`);
        break;
      }
 
      // Check for <|text_end|> token (130)
      if (token === SPECIAL_TOKENS.TEXT_END) {
-        log(`Text end at step ${step}`);
+        log(`→ TEXT_END at step ${step}: audio-only phase begins`);
        textDone = true;
      }
 
      // Switch to audio after N text tokens OR text_end
      if (modalityLeft <= 0 || textDone) {
+        log(`→ TEXT→AUDIO${textDone ? ' (text_done)' : ''} at step ${step}`);
        inAudioMode = true;
        modalityLeft = INTERLEAVED_N_AUDIO;
      }
@@ -1511,18 +1500,19 @@ export class AudioModel {
        onToken(decodedText, token);
      }
 
-      // Get embedding for next step
+      // === FEEDBACK: embed text token → feed back to LFM decoder ===
      const nextEmbeds = this.getTextEmbeddings([token]);
      currentLen++;
      const nextMask = new ort.Tensor('int64', new BigInt64Array(currentLen).fill(1n), [1, currentLen]);
      tStep = performance.now();
      ({ logits, hiddenStates, outputs } = await this.runDecoder(nextEmbeds, nextMask, this.cache));
-      timeTextDecode += performance.now() - tStep;
+      timeLfmText += performance.now() - tStep;
      this.updateCache(this.cache, outputs);
    }
  }
 
  // 5. Feed <|im_end|> token to close assistant turn in cache
+
  const imEndEmbeds = this.getTextEmbeddings([SPECIAL_TOKENS.IM_END]);
  currentLen++;
  const finalMask = new ort.Tensor('int64', new BigInt64Array(currentLen).fill(1n), [1, currentLen]);
@@ -1533,14 +1523,11 @@ export class AudioModel {
    // Decode with skip_special_tokens to clean up special tokens like <|text_end|>
    const text = this.tokenizer.decode(textTokens, { skip_special_tokens: true });
 
-    // Print timing summary
-    log(`=== Summary ===`);
-    log(`  Mel: ${timeMel.toFixed(0)}ms, AudioEnc: ${timeAudioEncode.toFixed(0)}ms, Prefill: ${timePrefill.toFixed(0)}ms`);
-    log(`  TextDec: ${timeTextDecode.toFixed(0)}ms (${textTokens.length} tok), AudioDec: ${timeAudioDecode.toFixed(0)}ms`);
-    log(`  Vocoder: ${timeVocoder.toFixed(0)}ms, AudioEmbed: ${timeAudioEmbed.toFixed(0)}ms`);
-    log(`Output: ${textTokens.length} text tokens, ${audioCodes.length} audio frames`);
-    log(`Text: "${text}"`);
-    log(`Cache seq_len: ${this.cacheSeqLen}`);
+    const totalGenTime = performance.now() - startTime;
+
+    log(`Done: ${step} steps in ${totalGenTime.toFixed(0)}ms | ${textTokens.length} text tokens, ${audioCodes.length} audio frames (~${(audioCodes.length / 75).toFixed(1)}s audio)`);
+    log(`Timing: mel=${timeMel.toFixed(0)}ms, audioEnc=${timeAudioEncode.toFixed(0)}ms, prefill=${timePrefill.toFixed(0)}ms, lfmText=${timeLfmText.toFixed(0)}ms, lfmAudio=${timeLfmAudio.toFixed(0)}ms, depthformer=${timeDepthformer.toFixed(0)}ms, audioEmbed=${timeAudioEmbed.toFixed(0)}ms`);
+    log(`Text: "${text}" | cache_seq_len=${this.cacheSeqLen}`);
 
    return { text, audioCodes };
  }
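
Note: for orientation, the loop above implements a fixed interleave schedule: start in text mode, emit INTERLEAVED_N_TEXT text tokens, flip to audio for INTERLEAVED_N_AUDIO frames, repeat; once <|text_end|> is seen it stays in audio until the first codebook returns 2048. A compressed sketch of just that state machine (constants as defined in audio-model.js; `sampleAudioFrame` and `sampleTextToken` are hypothetical stand-ins for the real decoder calls):

    async function interleavedSchedule(maxNewTokens, { sampleAudioFrame, sampleTextToken }) {
      let inAudioMode = false; // generation starts in text mode
      let textDone = false;
      let modalityLeft = INTERLEAVED_N_TEXT;
      for (let step = 0; step < maxNewTokens; step++) {
        modalityLeft--;
        if (inAudioMode) {
          const frame = await sampleAudioFrame();
          if (modalityLeft <= 0 && !textDone) { // audio quota used up: back to text
            inAudioMode = false;
            modalityLeft = INTERLEAVED_N_TEXT;
          }
          if (frame[0] === END_OF_AUDIO_TOKEN) inAudioMode = false; // audio finished
        } else {
          const token = await sampleTextToken();
          if (token === SPECIAL_TOKENS.IM_END) break;              // end of turn
          if (token === SPECIAL_TOKENS.TEXT_END) textDone = true;  // audio-only from here
          if (modalityLeft <= 0 || textDone) { // text quota used up or text done: to audio
            inAudioMode = true;
            modalityLeft = INTERLEAVED_N_AUDIO;
          }
        }
      }
    }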
@@ -1581,11 +1568,14 @@ export class AudioModel {
      throw new Error('Vocoder not loaded - required for interleaved mode');
    }
 
-    // Timing accumulators
+    // Timing accumulators (names match liquid-audio architecture):
+    // - lfmText/lfmAudio: self.lfm() decoder calls (text vs audio steps)
+    // - depthformer: self._sample_audio_frame() — depth_linear + 8× depthformer
+    // - audioEmbed: self.audio_embedding(...).sum() — feedback embedding
    let timePrefill = 0;
-    let timeTextDecode = 0;
-    let timeAudioDecode = 0;
-    let timeVocoder = 0;
+    let timeLfmText = 0;
+    let timeLfmAudio = 0;
+    let timeDepthformer = 0;
    let timeAudioEmbed = 0;
    let tStep;
 
@@ -1650,17 +1640,18 @@ export class AudioModel {
 
        tStep = performance.now();
        const frameCodes = await this.sampleAudioCodes(lastHidden, audioTemperature, audioTopK);
-        timeVocoder += performance.now() - tStep;
+        timeDepthformer += performance.now() - tStep;
 
        // Switch back to text after N audio frames (if text not done)
        if (modalityLeft <= 0 && !textDone) {
+          log(`→ AUDIO→TEXT (after ${INTERLEAVED_N_AUDIO} audio frames, ${audioCodes.length} total)`);
          inAudioMode = false;
          modalityLeft = INTERLEAVED_N_TEXT;
        }
 
        // Check for end of audio
        if (frameCodes[0] === END_OF_AUDIO_TOKEN) {
-          log(`End of audio at step ${step}`);
+          log(`→ END_OF_AUDIO at step ${step} (${audioCodes.length} frames collected)`);
          for (let i = 0; i < NUM_CODEBOOKS; i++) {
            frameCodes[i] = END_OF_AUDIO_TOKEN;
          }
@@ -1674,7 +1665,7 @@ export class AudioModel {
        }
 
        if (audioCodes.length % 50 === 0) {
-          log(`Generated ${audioCodes.length} audio frames`);
+          log(`  Audio frames: ${audioCodes.length}`);
        }
      }
 
@@ -1690,7 +1681,7 @@ export class AudioModel {
      const nextMask = new ort.Tensor('int64', new BigInt64Array(currentLen).fill(1n), [1, currentLen]);
      tStep = performance.now();
      ({ logits, hiddenStates, outputs } = await this.runDecoder(nextEmbeds, nextMask, this.cache));
-      timeAudioDecode += performance.now() - tStep;
+      timeLfmAudio += performance.now() - tStep;
      this.updateCache(this.cache, outputs);
 
    } else {
@@ -1706,18 +1697,19 @@ export class AudioModel {
 
      // Check for end of turn
      if (token === this.tokenizer.eos_token_id || token === SPECIAL_TOKENS.IM_END) {
-        log(`End of turn at step ${step}`);
+        log(`→ END_OF_TURN at step ${step} (${textTokens.length} text tokens, ${audioCodes.length} audio frames)`);
        break;
      }
 
      // Check for <|text_end|> token
      if (token === SPECIAL_TOKENS.TEXT_END) {
-        log(`Text end at step ${step}`);
+        log(`→ TEXT_END at step ${step}: audio-only phase begins`);
        textDone = true;
      }
 
      // Switch to audio after N text tokens OR text_end
      if (modalityLeft <= 0 || textDone) {
+        log(`→ TEXT→AUDIO${textDone ? ' (text_done)' : ''} at step ${step}`);
        inAudioMode = true;
        modalityLeft = INTERLEAVED_N_AUDIO;
      }
@@ -1735,7 +1727,7 @@ export class AudioModel {
      const nextMask = new ort.Tensor('int64', new BigInt64Array(currentLen).fill(1n), [1, currentLen]);
      tStep = performance.now();
      ({ logits, hiddenStates, outputs } = await this.runDecoder(nextEmbeds, nextMask, this.cache));
-      timeTextDecode += performance.now() - tStep;
+      timeLfmText += performance.now() - tStep;
      this.updateCache(this.cache, outputs);
    }
  }
@@ -1750,13 +1742,9 @@ export class AudioModel {
 
    const text = this.tokenizer.decode(textTokens, { skip_special_tokens: true });
 
-    log(`=== Summary ===`);
-    log(`  Prefill: ${timePrefill.toFixed(0)}ms`);
-    log(`  TextDec: ${timeTextDecode.toFixed(0)}ms (${textTokens.length} tok), AudioDec: ${timeAudioDecode.toFixed(0)}ms`);
-    log(`  Vocoder: ${timeVocoder.toFixed(0)}ms, AudioEmbed: ${timeAudioEmbed.toFixed(0)}ms`);
-    log(`Output: ${textTokens.length} text tokens, ${audioCodes.length} audio frames`);
-    log(`Text: "${text}"`);
-    log(`Cache seq_len: ${this.cacheSeqLen}`);
+    log(`Done: ${textTokens.length} text tokens, ${audioCodes.length} audio frames (~${(audioCodes.length / 75).toFixed(1)}s audio)`);
+    log(`Timing: prefill=${timePrefill.toFixed(0)}ms, lfmText=${timeLfmText.toFixed(0)}ms, lfmAudio=${timeLfmAudio.toFixed(0)}ms, depthformer=${timeDepthformer.toFixed(0)}ms, audioEmbed=${timeAudioEmbed.toFixed(0)}ms`);
+    log(`Text: "${text}" | cache_seq_len=${this.cacheSeqLen}`);
 
    return { text, audioCodes };
  }
@@ -1886,7 +1874,8 @@ export class AudioModel {
    }
 
    const decodeStart = performance.now();
-    log(`Decoding ${audioCodes.length} audio frames...`);
+
+    log(`Audio decode: ${audioCodes.length} frames → waveform`);
 
    // ISTFT parameters (fixed for this model)
    const nFft = 1280;
@@ -1894,7 +1883,7 @@ export class AudioModel {
    const winLength = 1280;
    const nFftBins = nFft / 2 + 1;
 
-    // Stack codes: [T, 8] -> [8, T] and add batch -> [1, 8, T]
+    // Transpose codes [T, 8] → [1, 8, T] for ONNX input
    const T = audioCodes.length;
    const codesTransposed = new BigInt64Array(8 * T);
    for (let t = 0; t < T; t++) {
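
Note: a quick dimensional sanity check on the fixed constants used here, assembled only from values visible in this diff:

    // nFftBins = nFft / 2 + 1 = 1280 / 2 + 1 = 641 one-sided frequency bins
    // detokenizer frame vector = 1282 = 641 log-magnitudes + 641 phase angles
    // detokenizer upsamples in time: T codec frames → 6T STFT frames
    // the generation summaries assume 75 codec frames ≈ 1 s of 24 kHz audio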
@@ -1903,18 +1892,18 @@ export class AudioModel {
      }
    }
 
-    // Run detokenizer: [1, 8, T] -> [1, T, 1282]
+    // Run detokenizer ONNX: [1, 8, T] → [1, 6T, 1282]
    const codesTensor = new ort.Tensor('int64', codesTransposed, [1, 8, T]);
    const detokStart = performance.now();
    const detokOutputs = await this.audioDetokenizerSession.run({ audio_codes: codesTensor });
    const stftFeatures = detokOutputs.stft_features;
-    log(`Detokenizer: ${(performance.now() - detokStart).toFixed(0)}ms, STFT frames: ${stftFeatures.dims[1]}`);
+    const actualT = stftFeatures.dims[1];
+    const detokEnd = performance.now();
+    log(`  Detokenizer: [1,8,${T}] → [1,${actualT},1282] in ${(detokEnd - detokStart).toFixed(0)}ms`);
 
-    // Get raw data - shape is [1, T, 1282], we need to skip batch dimension
+    // Split into magnitude + angle → complex spectrogram
    const stftData = stftFeatures.data;
-    const actualT = stftFeatures.dims[1];
 
-    // Convert to complex STFT: [log_magnitude | angle] -> complex
    const complexStft = new Array(nFftBins);
    for (let f = 0; f < nFftBins; f++) {
      complexStft[f] = new Array(actualT);
@@ -1922,23 +1911,23 @@ export class AudioModel {
        const logMag = stftData[t * 1282 + f];
        const angle = stftData[t * 1282 + nFftBins + f];
        const mag = Math.exp(logMag);
-        // Store as [real, imag]
        complexStft[f][t] = [mag * Math.cos(angle), mag * Math.sin(angle)];
      }
    }
 
-    // ISTFT with 'same' padding
+    // ISTFT (inverse Short-Time Fourier Transform) → waveform
+    const pad = (winLength - hopLength) / 2;
    const istftStart = performance.now();
    const waveform = this.istftSamePadding(complexStft, nFft, hopLength, winLength, actualT);
-    log(`ISTFT: ${(performance.now() - istftStart).toFixed(0)}ms`);
+    const istftEnd = performance.now();
+    log(`  ISTFT: ${actualT} frames → ${waveform.length} samples in ${(istftEnd - istftStart).toFixed(0)}ms`);
 
-    // Find max/min without spread operator (avoid stack overflow on large arrays)
+    // Find max/min
    let waveMax = -Infinity, waveMin = Infinity;
    for (let i = 0; i < waveform.length; i++) {
      if (waveform[i] > waveMax) waveMax = waveform[i];
      if (waveform[i] < waveMin) waveMin = waveform[i];
    }
-    log('ISTFT output - length:', waveform.length, 'max:', waveMax.toFixed(4), 'min:', waveMin.toFixed(4));
 
    // Check for invalid values
    if (isNaN(waveMax) || isNaN(waveMin) || !isFinite(waveMax) || !isFinite(waveMin)) {
@@ -1946,7 +1935,7 @@ export class AudioModel {
      return new Float32Array(0);
    }
 
-    // Normalize to [-1, 1]
+    // Normalize waveform to [-0.9, 0.9]
    let maxVal = Math.max(Math.abs(waveMax), Math.abs(waveMin));
    if (maxVal > 0) {
      for (let i = 0; i < waveform.length; i++) {
@@ -1956,7 +1945,8 @@ export class AudioModel {
      console.warn('ISTFT produced all-zero waveform');
    }
 
-    log(`Decoded audio: ${waveform.length} samples (${(waveform.length / 24000).toFixed(2)}s)`);
+    const totalDecodeTime = performance.now() - decodeStart;
+    log(`Audio decode complete: ${totalDecodeTime.toFixed(0)}ms total (detok=${(detokEnd - detokStart).toFixed(0)}ms, istft=${(istftEnd - istftStart).toFixed(0)}ms) → ${waveform.length} samples, ${(waveform.length / 24000).toFixed(2)}s @ 24kHz`);
    return waveform;
  }
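
Note: the decode path condenses to: detokenize codes into per-frame [log-magnitude | phase] vectors, turn each bin into the complex number e^(logMag)·(cos θ + i·sin θ), then run the inverse STFT and normalize. The per-frame conversion as a standalone sketch (hypothetical helper mirroring the loop above):

    // One detokenizer frame: 1282 floats → 641 complex STFT bins.
    function frameToComplex(frame, nFftBins = 641) {
      const bins = new Array(nFftBins);
      for (let f = 0; f < nFftBins; f++) {
        const mag = Math.exp(frame[f]);    // first half: log-magnitudes
        const angle = frame[nFftBins + f]; // second half: phases in radians
        bins[f] = [mag * Math.cos(angle), mag * Math.sin(angle)]; // [re, im]
      }
      return bins;
    }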
main.js CHANGED
@@ -6,13 +6,13 @@
 
 import { AudioModel, loadAudioFile, clearModelCache, getCacheInfo } from './audio-model.js';
 
-// HuggingFace model URL
-const MODEL_URL = 'https://huggingface.co/LiquidAI/LFM2.5-Audio-1.5B-ONNX/resolve/main';
+// Model path - local directory
+const MODEL_PATH = './LFM2.5-Audio-1.5B-ONNX';
 
 // Model configurations
 const MODELS = {
   'LFM2.5-Audio-1.5B-Q4': {
-    path: MODEL_URL,
+    path: MODEL_PATH,
     label: 'LFM2.5-Audio-1.5B Q4 (~1.6 GB)',
     quantization: {
       decoder: 'q4',
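
Note: the loader builds URLs as `${modelPath}/onnx/${fileName}.onnx`, so the local directory has to mirror the repo layout. An inferred sketch (file names are illustrative; actual names depend on the quantization suffixes configured above):

    // ./LFM2.5-Audio-1.5B-ONNX/
    //   onnx/
    //     decoder_q4.onnx            // `${name}_${quantSuffix}.onnx`
    //     vocoder_depthformer_*.onnx
    //     ...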
 