correct the input
audio-model.js (CHANGED: +146, -119)
@@ -397,7 +397,7 @@ export class AudioModel {
     };

     // Helper to load ONNX model with external data
-    const loadOnnxWithExternalData = async (name, progress, quantSuffix = null,
+    const loadOnnxWithExternalData = async (name, progress, quantSuffix = null, extraOptions = {}) => {
      const suffix = quantSuffix ? `_${quantSuffix}` : '';
      const fileName = `${name}${suffix}`;
      report('loading', progress, `${fileName}.onnx`);
@@ -405,10 +405,9 @@ export class AudioModel {
      const onnxPath = `${modelPath}/onnx/${fileName}.onnx`;
      const fetchOptions = { mode: 'cors', credentials: 'omit' };

-
-      console.log(`Loading ${fileName} (EP: ${JSON.stringify(ep)})...`);
+      console.log(`Loading ${fileName}...`);

-      const sessionOptions = { executionProviders
+      const sessionOptions = { executionProviders, ...extraOptions };

      const onnxResponse = await fetchWithCache(onnxPath, fetchOptions);
      if (!onnxResponse.ok) {
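Note: the new `extraOptions` parameter lets each loader call layer session-specific settings over the shared defaults via an object spread. A minimal sketch of the merge semantics (the option values here are illustrative, not taken from this commit):

// Later keys win, so callers can override defaults per model.
const executionProviders = ['webgpu', 'wasm'];
const extraOptions = { preferredOutputLocation: { new_keys: 'gpu-buffer' } };
const sessionOptions = { executionProviders, ...extraOptions };
// => { executionProviders: ['webgpu', 'wasm'],
//      preferredOutputLocation: { new_keys: 'gpu-buffer' } }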
@@ -478,7 +477,7 @@ export class AudioModel {
      }
      return { preferredOutputLocation: loc };
    })() : {};
-    this.decoderSession = await loadOnnxWithExternalData('decoder', 10, quantConfig.decoder,
+    this.decoderSession = await loadOnnxWithExternalData('decoder', 10, quantConfig.decoder, decoderOpts);

    // Load embed_tokens weight for text embedding lookup
    report('loading', 30, 'embed_tokens');
@@ -504,15 +503,13 @@ export class AudioModel {
      console.warn('Audio detokenizer not available:', e);
    }

-    // Load vocoder
+    // Load vocoder (for TTS)
    // On WebGPU: keep KV cache on GPU to avoid GPU→CPU→GPU roundtrips between steps
    try {
      const vocoderOpts = device === 'webgpu'
        ? { preferredOutputLocation: { new_keys: 'gpu-buffer', new_values: 'gpu-buffer' } }
        : {};
-      this.vocoderSession = await loadOnnxWithExternalData(
-        'vocoder_depthformer', 95, quantConfig.vocoder, null, vocoderOpts,
-      );
+      this.vocoderSession = await loadOnnxWithExternalData('vocoder_depthformer', 95, quantConfig.vocoder, vocoderOpts);
    } catch (e) {
      console.warn('Vocoder not available:', e);
    }
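Note: `preferredOutputLocation` is an onnxruntime-web session option that pins named outputs to GPU memory, so the vocoder's KV outputs can be fed straight back in on the next step. A sketch of the idea in isolation, assuming the WebGPU build of onnxruntime-web (the model path here is illustrative):

import * as ort from 'onnxruntime-web/webgpu';

// Keep the recurrent KV outputs on-device between decode steps.
const session = await ort.InferenceSession.create('vocoder_depthformer.onnx', {
  executionProviders: ['webgpu'],
  preferredOutputLocation: { new_keys: 'gpu-buffer', new_values: 'gpu-buffer' },
});
// Outputs mapped to 'gpu-buffer' come back as GPU-resident tensors; passing
// them directly into the next run() avoids the GPU→CPU→GPU roundtrip.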
@@ -954,103 +951,129 @@ export class AudioModel {
    return '[Text generation requires full embedding support - model loaded successfully]';
  }

+  /**
+   * Initialize reusable vocoder tensors to reduce allocation overhead
+   */
  _initVocoderCache() {
    if (this._vocoderCache) return;

-    const
+    const numLayers = 6;
+    const numKvHeads = 8;
+    const headDim = 32;
+
+    // Pre-allocate data arrays
    const stepIdxData = new BigInt64Array(1);
    const prevTokenData = new BigInt64Array(1);
-    const seqlensKData = new Int32Array(1);
-    const totalSeqLenData = new Int32Array(1);

+    // Pre-allocate tensors that can be reused
    this._vocoderCache = {
+      hiddenTensor: null, // Created per-call since hiddenState changes
      stepIdxData,
      prevTokenData,
-
-      totalSeqLenData,
+      // Pre-create reusable tensors (ONNX Runtime reads from the data array)
      stepIdxTensor: new ort.Tensor('int64', stepIdxData, []),
      prevTokenTensor: new ort.Tensor('int64', prevTokenData, [1]),
-
-
-
-      //
-
-
-      probs: new Float32Array(64),
+      emptyKeysData: new Float32Array(0),
+      emptyValuesData: new Float32Array(0),
+      // Reusable sampling arrays
+      scaledLogits: new Float32Array(2049), // codebook vocab size
+      indices: new Uint16Array(2049), // Use typed array for faster reset
+      probs: new Float32Array(64), // top-k size
    };
+
+    // Initialize indices
+    for (let i = 0; i < 2049; i++) {
+      this._vocoderCache.indices[i] = i;
+    }
  }

  /**
-   * Sample audio codes using
-   *
+   * Sample audio codes using vocoder depthformer
+   * Optimized to reduce tensor creation overhead
   * @param {Float32Array} hiddenState - [hidden_size] hidden state
   * @param {number} temperature - Sampling temperature
-   * @param {number} topK - Top-k sampling
+   * @param {number} topK - Top-k sampling
   * @returns {number[]} - 8 codebook values
   */
  async sampleAudioCodes(hiddenState, temperature = 0.8, topK = 64) {
+    if (!this.vocoderSession) {
+      throw new Error('Vocoder not loaded');
+    }
+
+    // Initialize cache on first call
    this._initVocoderCache();
    const cache = this._vocoderCache;

    const numCodebooks = 8;
    const numLayers = 6;
-    const
+    const numKvHeads = 8;
    const headDim = 32;
-    const vocabSize = 2049;
-
-    const hiddenTensor = new ort.Tensor('float32', hiddenState, [1, this.hiddenSize]);
-    // BNSH format: [layers, batch, heads, seq_len, head_dim]
-    let pastKeys = new ort.Tensor('float32', cache.emptyData, [numLayers, 1, numKVHeads, 0, headDim]);
-    let pastValues = new ort.Tensor('float32', cache.emptyData, [numLayers, 1, numKVHeads, 0, headDim]);
-    let depthSlices = new ort.Tensor('float32', new Float32Array(numCodebooks * 1024), [1, numCodebooks, 1024]);

    const codes = [];
    let prevToken = 0;

-
-
+    // Create hidden state tensor (must be new since data changes)
+    const hiddenTensor = new ort.Tensor('float32', hiddenState, [1, this.hiddenSize]);
+
+    // Initialize empty KV cache
+    let pastKeys = new ort.Tensor(
+      'float32',
+      cache.emptyKeysData,
+      [numLayers, 1, 0, numKvHeads, headDim]
+    );
+    let pastValues = new ort.Tensor(
+      'float32',
+      cache.emptyValuesData,
+      [numLayers, 1, 0, numKvHeads, headDim]
+    );
+
+    // Reuse step_idx and prev_token tensors by updating their data
+    cache.stepIdxData[0] = 0n;
+    cache.prevTokenData[0] = 0n;
+
+    for (let i = 0; i < numCodebooks; i++) {
+      // Update mutable tensor data (tensor objects reuse the underlying data arrays)
+      cache.stepIdxData[0] = BigInt(i);
      cache.prevTokenData[0] = BigInt(prevToken);
-      cache.seqlensKData[0] = step;
-      cache.totalSeqLenData[0] = step + 1;

-      const
+      const feeds = {
        hidden_states: hiddenTensor,
-        depth_slices_in: depthSlices,
        step_idx: cache.stepIdxTensor,
        prev_token: cache.prevTokenTensor,
        past_keys: pastKeys,
        past_values: pastValues,
-
-        total_seq_len: cache.totalSeqLenTensor,
-      });
-
-      if (step === 0) {
-        depthSlices = outputs.depth_slices;
-      }
-
-      pastKeys = outputs.new_keys;
-      pastValues = outputs.new_values;
+      };

+      const outputs = await this.vocoderSession.run(feeds);
      const logits = outputs.logits.data;
+      const vocabSize = logits.length;

+      // Sample with temperature and top-k (reusing cached arrays)
      let token;
-      if (temperature <= 0
+      if (temperature <= 0) {
+        // Greedy
        token = 0;
        let maxVal = logits[0];
        for (let j = 1; j < vocabSize; j++) {
-          if (logits[j] > maxVal) {
+          if (logits[j] > maxVal) {
+            maxVal = logits[j];
+            token = j;
+          }
        }
      } else {
+        // Top-k sampling with reused arrays
        const scaledLogits = cache.scaledLogits;
        const indices = cache.indices;
        const probs = cache.probs;

+        // Scale logits by temperature and find top-k in single pass
+        // Use partial selection sort (O(k*n) which is fast for small k)
        for (let j = 0; j < vocabSize; j++) {
          scaledLogits[j] = logits[j] / temperature;
          indices[j] = j;
        }

-        // Partial
+        // Partial sort to get top-k
        for (let j = 0; j < topK; j++) {
          let maxIdx = j;
          for (let k = j + 1; k < vocabSize; k++) {
@@ -1058,6 +1081,7 @@ export class AudioModel {
            maxIdx = k;
          }
        }
+        // Swap
        const tmp = indices[j];
        indices[j] = indices[maxIdx];
        indices[maxIdx] = tmp;
@@ -1074,18 +1098,25 @@ export class AudioModel {
          probs[j] /= sumExp;
        }

-        // Sample
+        // Sample
        const r = Math.random();
        let cumsum = 0;
-        token = indices[topK - 1];
+        token = indices[topK - 1]; // Default to last
        for (let j = 0; j < topK; j++) {
          cumsum += probs[j];
-          if (r < cumsum) {
+          if (r < cumsum) {
+            token = indices[j];
+            break;
+          }
        }
      }

      codes.push(token);
      prevToken = token;
+
+      // Update KV cache
+      pastKeys = outputs.new_keys;
+      pastValues = outputs.new_values;
    }

    return codes;
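Note: stripped of the ONNX plumbing, the sampler above is plain temperature plus top-k sampling: scale the logits, partially sort an index array so the k largest come first, softmax over those k, then draw by inverse CDF. A self-contained sketch of that scheme (assumes topK <= logits.length):

function sampleTopK(logits, temperature = 0.8, topK = 64) {
  const n = logits.length;
  const scaled = new Float32Array(n);
  const indices = new Uint32Array(n);
  for (let j = 0; j < n; j++) {
    scaled[j] = logits[j] / temperature;
    indices[j] = j;
  }

  // Partial selection sort: order only the first topK slots, O(k*n).
  for (let j = 0; j < topK; j++) {
    let maxIdx = j;
    for (let k = j + 1; k < n; k++) {
      if (scaled[indices[k]] > scaled[indices[maxIdx]]) maxIdx = k;
    }
    const tmp = indices[j]; indices[j] = indices[maxIdx]; indices[maxIdx] = tmp;
  }

  // Softmax over the k survivors (subtract the max for numerical stability).
  const probs = new Float32Array(topK);
  let sumExp = 0;
  for (let j = 0; j < topK; j++) {
    probs[j] = Math.exp(scaled[indices[j]] - scaled[indices[0]]);
    sumExp += probs[j];
  }
  for (let j = 0; j < topK; j++) probs[j] /= sumExp;

  // Inverse-CDF draw; defaulting to the last candidate guards against
  // floating-point rounding leaving the cumulative sum slightly below 1.
  const r = Math.random();
  let cumsum = 0;
  let token = indices[topK - 1];
  for (let j = 0; j < topK; j++) {
    cumsum += probs[j];
    if (r < cumsum) { token = indices[j]; break; }
  }
  return token;
}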
@@ -1290,15 +1321,12 @@ export class AudioModel {
      throw new Error('Vocoder not loaded - required for interleaved mode');
    }

-    // Timing accumulators
-    // - lfmText/lfmAudio: self.lfm() decoder calls (text vs audio steps)
-    // - depthformer: self._sample_audio_frame() — depth_linear + 8× depthformer
-    // - audioEmbed: self.audio_embedding(...).sum() — feedback embedding
+    // Timing accumulators
    let timeAudioEncode = 0;
    let timePrefill = 0;
-    let
-    let
-    let
+    let timeTextDecode = 0;
+    let timeAudioDecode = 0;
+    let timeVocoder = 0;
    let timeAudioEmbed = 0;

    // 1. Compute mel spectrogram and encode audio
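Note: the renamed accumulators all follow one pattern in the loops below: stamp `tStep = performance.now()` before a session call, then add the elapsed time into the matching bucket afterwards, e.g.:

tStep = performance.now();
// ... await this.runDecoder(...) for a text step ...
timeTextDecode += performance.now() - tStep; // accumulate per-stage cost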
@@ -1404,37 +1432,36 @@ export class AudioModel {

    const startTime = performance.now();

-
-
-    let step = 0;
-    for (; step < maxNewTokens; step++) {
+    for (let step = 0; step < maxNewTokens; step++) {
      modalityLeft--;

      if (inAudioMode) {
-        //
+        // Generate audio frame using depthformer
        const hiddenData = hiddenStates.data;
        const seqLen = hiddenStates.dims[1];
        const lastHidden = hiddenData.slice((seqLen - 1) * hiddenSize, seqLen * hiddenSize);

        tStep = performance.now();
        const frameCodes = await this.sampleAudioCodes(lastHidden, audioTemperature, audioTopK);
-
+        timeVocoder += performance.now() - tStep;

        // Switch back to text after N audio frames (if text not done)
        if (modalityLeft <= 0 && !textDone) {
-          log(`→ AUDIO→TEXT (after ${INTERLEAVED_N_AUDIO} audio frames, ${audioCodes.length} total)`);
          inAudioMode = false;
          modalityLeft = INTERLEAVED_N_TEXT;
        }

        // Check for end of audio - first codebook == 2048 (matching liquid-audio)
        if (frameCodes[0] === END_OF_AUDIO_TOKEN) {
-          log(
+          log(`End of audio at step ${step}`);
+          // Set all codes to 2048 (matching liquid-audio)
          for (let i = 0; i < NUM_CODEBOOKS; i++) {
            frameCodes[i] = END_OF_AUDIO_TOKEN;
          }
          inAudioMode = false;
+          // Don't save this frame, but still feed it back
        } else {
+          // Save valid frame (clamped to 0-2047)
          const clampedFrame = frameCodes.map(c => Math.min(c, 2047));
          audioCodes.push(clampedFrame);

@@ -1443,15 +1470,16 @@ export class AudioModel {
          }

          if (audioCodes.length % 50 === 0) {
-            log(`
+            log(`Generated ${audioCodes.length} audio frames`);
          }
        }

-        //
+        // Get embeddings for next step (always feed back, even for 2048 frames)
        tStep = performance.now();
        const feedCodes = frameCodes.map(c => c === END_OF_AUDIO_TOKEN ? END_OF_AUDIO_TOKEN : Math.min(c, 2047));
        const audioTokens = feedCodes.map((code, idx) => idx * CODEBOOK_VOCAB + code);

+        // Get summed embeddings for all 8 codebooks
        const summedEmbeds = await this.getAudioEmbedding(audioTokens);
        timeAudioEmbed += performance.now() - tStep;

@@ -1460,13 +1488,14 @@ export class AudioModel {
        const nextMask = new ort.Tensor('int64', new BigInt64Array(currentLen).fill(1n), [1, currentLen]);
        tStep = performance.now();
        ({ logits, hiddenStates, outputs } = await this.runDecoder(nextEmbeds, nextMask, this.cache));
-
+        timeAudioDecode += performance.now() - tStep;
        this.updateCache(this.cache, outputs);

      } else {
-        //
+        // Generate text token
        const logitsData = logits.data;
        const seqLen = logits.dims[1];
+        // Get logits for last position - shape is [1, seq_len, vocab_size]
        const lastLogits = new Float32Array(this.vocabSize);
        const offset = (seqLen - 1) * this.vocabSize;
        for (let i = 0; i < this.vocabSize; i++) {
@@ -1476,19 +1505,18 @@ export class AudioModel {

        // Check for end of turn
        if (token === this.tokenizer.eos_token_id || token === SPECIAL_TOKENS.IM_END) {
-          log(
+          log(`End of turn at step ${step}`);
          break;
        }

        // Check for <|text_end|> token (130)
        if (token === SPECIAL_TOKENS.TEXT_END) {
-          log(
+          log(`Text end at step ${step}`);
          textDone = true;
        }

        // Switch to audio after N text tokens OR text_end
        if (modalityLeft <= 0 || textDone) {
-          log(`→ TEXT→AUDIO${textDone ? ' (text_done)' : ''} at step ${step}`);
          inAudioMode = true;
          modalityLeft = INTERLEAVED_N_AUDIO;
        }
@@ -1500,19 +1528,18 @@ export class AudioModel {
          onToken(decodedText, token);
        }

-        //
+        // Get embedding for next step
        const nextEmbeds = this.getTextEmbeddings([token]);
        currentLen++;
        const nextMask = new ort.Tensor('int64', new BigInt64Array(currentLen).fill(1n), [1, currentLen]);
        tStep = performance.now();
        ({ logits, hiddenStates, outputs } = await this.runDecoder(nextEmbeds, nextMask, this.cache));
-
+        timeTextDecode += performance.now() - tStep;
        this.updateCache(this.cache, outputs);
      }
    }

    // 5. Feed <|im_end|> token to close assistant turn in cache
-
    const imEndEmbeds = this.getTextEmbeddings([SPECIAL_TOKENS.IM_END]);
    currentLen++;
    const finalMask = new ort.Tensor('int64', new BigInt64Array(currentLen).fill(1n), [1, currentLen]);
@@ -1523,11 +1550,14 @@ export class AudioModel {
    // Decode with skip_special_tokens to clean up special tokens like <|text_end|>
    const text = this.tokenizer.decode(textTokens, { skip_special_tokens: true });

-
-
-    log(`
-    log(`
-    log(`
+    // Print timing summary
+    log(`=== Summary ===`);
+    log(`  Mel: ${timeMel.toFixed(0)}ms, AudioEnc: ${timeAudioEncode.toFixed(0)}ms, Prefill: ${timePrefill.toFixed(0)}ms`);
+    log(`  TextDec: ${timeTextDecode.toFixed(0)}ms (${textTokens.length} tok), AudioDec: ${timeAudioDecode.toFixed(0)}ms`);
+    log(`  Vocoder: ${timeVocoder.toFixed(0)}ms, AudioEmbed: ${timeAudioEmbed.toFixed(0)}ms`);
+    log(`Output: ${textTokens.length} text tokens, ${audioCodes.length} audio frames`);
+    log(`Text: "${text}"`);
+    log(`Cache seq_len: ${this.cacheSeqLen}`);

    return { text, audioCodes };
  }
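Note: the loop above alternates modalities on a simple countdown: a burst of text tokens, then a burst of audio frames, until text hits <|text_end|> and audio emits code 2048. Its control flow reduces to the skeleton below (the burst sizes are assumptions for illustration; the real INTERLEAVED_N_* constants are defined elsewhere in this file):

function interleaveSchedule(maxNewTokens, nText = 4, nAudio = 12) {
  let inAudioMode = false;
  let textDone = false;
  let modalityLeft = nText;
  const trace = [];

  for (let step = 0; step < maxNewTokens; step++) {
    modalityLeft--;
    if (inAudioMode) {
      trace.push('A'); // one audio frame would be sampled here
      if (modalityLeft <= 0 && !textDone) {
        inAudioMode = false;
        modalityLeft = nText;
      }
    } else {
      trace.push('T'); // one text token; the real loop may set textDone here
      if (modalityLeft <= 0 || textDone) {
        inAudioMode = true;
        modalityLeft = nAudio;
      }
    }
  }
  return trace.join(''); // e.g. 'TTTTAAAAAAAAAAAATTTT...'
}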
@@ -1568,14 +1598,11 @@ export class AudioModel {
      throw new Error('Vocoder not loaded - required for interleaved mode');
    }

-    // Timing accumulators
-    // - lfmText/lfmAudio: self.lfm() decoder calls (text vs audio steps)
-    // - depthformer: self._sample_audio_frame() — depth_linear + 8× depthformer
-    // - audioEmbed: self.audio_embedding(...).sum() — feedback embedding
+    // Timing accumulators
    let timePrefill = 0;
-    let
-    let
-    let
+    let timeTextDecode = 0;
+    let timeAudioDecode = 0;
+    let timeVocoder = 0;
    let timeAudioEmbed = 0;
    let tStep;

@@ -1640,18 +1667,17 @@ export class AudioModel {

        tStep = performance.now();
        const frameCodes = await this.sampleAudioCodes(lastHidden, audioTemperature, audioTopK);
-
+        timeVocoder += performance.now() - tStep;

        // Switch back to text after N audio frames (if text not done)
        if (modalityLeft <= 0 && !textDone) {
-          log(`→ AUDIO→TEXT (after ${INTERLEAVED_N_AUDIO} audio frames, ${audioCodes.length} total)`);
          inAudioMode = false;
          modalityLeft = INTERLEAVED_N_TEXT;
        }

        // Check for end of audio
        if (frameCodes[0] === END_OF_AUDIO_TOKEN) {
-          log(
+          log(`End of audio at step ${step}`);
          for (let i = 0; i < NUM_CODEBOOKS; i++) {
            frameCodes[i] = END_OF_AUDIO_TOKEN;
          }
@@ -1665,7 +1691,7 @@ export class AudioModel {
          }

          if (audioCodes.length % 50 === 0) {
-            log(`
+            log(`Generated ${audioCodes.length} audio frames`);
          }
        }

@@ -1681,7 +1707,7 @@ export class AudioModel {
        const nextMask = new ort.Tensor('int64', new BigInt64Array(currentLen).fill(1n), [1, currentLen]);
        tStep = performance.now();
        ({ logits, hiddenStates, outputs } = await this.runDecoder(nextEmbeds, nextMask, this.cache));
-
+        timeAudioDecode += performance.now() - tStep;
        this.updateCache(this.cache, outputs);

      } else {
@@ -1697,19 +1723,18 @@ export class AudioModel {

        // Check for end of turn
        if (token === this.tokenizer.eos_token_id || token === SPECIAL_TOKENS.IM_END) {
-          log(
+          log(`End of turn at step ${step}`);
          break;
        }

        // Check for <|text_end|> token
        if (token === SPECIAL_TOKENS.TEXT_END) {
-          log(
+          log(`Text end at step ${step}`);
          textDone = true;
        }

        // Switch to audio after N text tokens OR text_end
        if (modalityLeft <= 0 || textDone) {
-          log(`→ TEXT→AUDIO${textDone ? ' (text_done)' : ''} at step ${step}`);
          inAudioMode = true;
          modalityLeft = INTERLEAVED_N_AUDIO;
        }
@@ -1727,7 +1752,7 @@ export class AudioModel {
        const nextMask = new ort.Tensor('int64', new BigInt64Array(currentLen).fill(1n), [1, currentLen]);
        tStep = performance.now();
        ({ logits, hiddenStates, outputs } = await this.runDecoder(nextEmbeds, nextMask, this.cache));
-
+        timeTextDecode += performance.now() - tStep;
        this.updateCache(this.cache, outputs);
      }
    }
@@ -1742,9 +1767,13 @@ export class AudioModel {

    const text = this.tokenizer.decode(textTokens, { skip_special_tokens: true });

-    log(
-    log(`
-    log(`
+    log(`=== Summary ===`);
+    log(`  Prefill: ${timePrefill.toFixed(0)}ms`);
+    log(`  TextDec: ${timeTextDecode.toFixed(0)}ms (${textTokens.length} tok), AudioDec: ${timeAudioDecode.toFixed(0)}ms`);
+    log(`  Vocoder: ${timeVocoder.toFixed(0)}ms, AudioEmbed: ${timeAudioEmbed.toFixed(0)}ms`);
+    log(`Output: ${textTokens.length} text tokens, ${audioCodes.length} audio frames`);
+    log(`Text: "${text}"`);
+    log(`Cache seq_len: ${this.cacheSeqLen}`);

    return { text, audioCodes };
  }
@@ -1874,8 +1903,7 @@ export class AudioModel {
    }

    const decodeStart = performance.now();
-
-    log(`Audio decode: ${audioCodes.length} frames → waveform`);
+    log(`Decoding ${audioCodes.length} audio frames...`);

    // ISTFT parameters (fixed for this model)
    const nFft = 1280;
@@ -1883,7 +1911,7 @@ export class AudioModel {
    const winLength = 1280;
    const nFftBins = nFft / 2 + 1;

-    //
+    // Stack codes: [T, 8] -> [8, T] and add batch -> [1, 8, T]
    const T = audioCodes.length;
    const codesTransposed = new BigInt64Array(8 * T);
    for (let t = 0; t < T; t++) {
@@ -1892,18 +1920,18 @@ export class AudioModel {
      }
    }

-    // Run detokenizer
+    // Run detokenizer: [1, 8, T] -> [1, T, 1282]
    const codesTensor = new ort.Tensor('int64', codesTransposed, [1, 8, T]);
    const detokStart = performance.now();
    const detokOutputs = await this.audioDetokenizerSession.run({ audio_codes: codesTensor });
    const stftFeatures = detokOutputs.stft_features;
-
-    const detokEnd = performance.now();
-    log(`  Detokenizer: [1,8,${T}] → [1,${actualT},1282] in ${(detokEnd - detokStart).toFixed(0)}ms`);
+    log(`Detokenizer: ${(performance.now() - detokStart).toFixed(0)}ms, STFT frames: ${stftFeatures.dims[1]}`);

-    //
+    // Get raw data - shape is [1, T, 1282], we need to skip batch dimension
    const stftData = stftFeatures.data;
+    const actualT = stftFeatures.dims[1];

+    // Convert to complex STFT: [log_magnitude | angle] -> complex
    const complexStft = new Array(nFftBins);
    for (let f = 0; f < nFftBins; f++) {
      complexStft[f] = new Array(actualT);
@@ -1911,23 +1939,23 @@ export class AudioModel {
        const logMag = stftData[t * 1282 + f];
        const angle = stftData[t * 1282 + nFftBins + f];
        const mag = Math.exp(logMag);
+        // Store as [real, imag]
        complexStft[f][t] = [mag * Math.cos(angle), mag * Math.sin(angle)];
      }
    }

-    // ISTFT
-    const pad = (winLength - hopLength) / 2;
+    // ISTFT with 'same' padding
    const istftStart = performance.now();
    const waveform = this.istftSamePadding(complexStft, nFft, hopLength, winLength, actualT);
-
-    log(`  ISTFT: ${actualT} frames → ${waveform.length} samples in ${(istftEnd - istftStart).toFixed(0)}ms`);
+    log(`ISTFT: ${(performance.now() - istftStart).toFixed(0)}ms`);

-    // Find max/min
+    // Find max/min without spread operator (avoid stack overflow on large arrays)
    let waveMax = -Infinity, waveMin = Infinity;
    for (let i = 0; i < waveform.length; i++) {
      if (waveform[i] > waveMax) waveMax = waveform[i];
      if (waveform[i] < waveMin) waveMin = waveform[i];
    }
+    log('ISTFT output - length:', waveform.length, 'max:', waveMax.toFixed(4), 'min:', waveMin.toFixed(4));

    // Check for invalid values
    if (isNaN(waveMax) || isNaN(waveMin) || !isFinite(waveMax) || !isFinite(waveMin)) {
@@ -1935,7 +1963,7 @@ export class AudioModel {
      return new Float32Array(0);
    }

-    // Normalize
+    // Normalize to [-1, 1]
    let maxVal = Math.max(Math.abs(waveMax), Math.abs(waveMin));
    if (maxVal > 0) {
      for (let i = 0; i < waveform.length; i++) {
@@ -1945,8 +1973,7 @@ export class AudioModel {
      console.warn('ISTFT produced all-zero waveform');
    }

-
-    log(`Audio decode complete: ${totalDecodeTime.toFixed(0)}ms total (detok=${(detokEnd - detokStart).toFixed(0)}ms, istft=${(istftEnd - istftStart).toFixed(0)}ms) → ${waveform.length} samples, ${(waveform.length / 24000).toFixed(2)}s @ 24kHz`);
+    log(`Decoded audio: ${waveform.length} samples (${(waveform.length / 24000).toFixed(2)}s)`);
    return waveform;
  }
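Note: each detokenizer frame packs 1282 floats: 641 log-magnitudes followed by 641 phase angles (nFft/2 + 1 = 641 bins for the 1280-point FFT). Recovering the complex spectrogram is a per-bin exp and cos/sin, as in this self-contained sketch of the same unpacking:

// Unpack [logMag | phase] detokenizer features into a complex STFT.
function toComplexStft(stftData, frames, nFftBins = 641) {
  const width = 2 * nFftBins; // 1282 floats per frame
  const complexStft = new Array(nFftBins);
  for (let f = 0; f < nFftBins; f++) {
    complexStft[f] = new Array(frames);
    for (let t = 0; t < frames; t++) {
      const mag = Math.exp(stftData[t * width + f]);    // log-magnitude -> magnitude
      const angle = stftData[t * width + nFftBins + f]; // phase in radians
      complexStft[f][t] = [mag * Math.cos(angle), mag * Math.sin(angle)];
    }
  }
  return complexStft; // feed to an ISTFT with 'same' padding, then peak-normalize
}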