"""
5Hz LM (Language Model) Handler
Handles all LM-related operations including initialization and generation
"""
import os
import traceback
import time
import random
from typing import Optional, Dict, Any, Tuple, List, Union
from contextlib import contextmanager
import yaml
import torch
from loguru import logger
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.generation.streamers import BaseStreamer
from transformers.generation.logits_process import (
LogitsProcessorList,
RepetitionPenaltyLogitsProcessor,
)
from acestep.constrained_logits_processor import MetadataConstrainedLogitsProcessor
from acestep.constants import DEFAULT_LM_INSTRUCTION, DEFAULT_LM_UNDERSTAND_INSTRUCTION, DEFAULT_LM_INSPIRED_INSTRUCTION, DEFAULT_LM_REWRITE_INSTRUCTION
class LLMHandler:
"""5Hz LM Handler for audio code generation"""
STOP_REASONING_TAG = "</think>"
# HuggingFace Space environment detection
IS_HUGGINGFACE_SPACE = os.environ.get("SPACE_ID") is not None
def __init__(self, persistent_storage_path: Optional[str] = None):
"""Initialize LLMHandler with default values"""
self.llm = None
self.llm_tokenizer = None
self.llm_initialized = False
self.llm_backend = None
self.max_model_len = 4096
self.device = "cpu"
self.dtype = torch.float32
self.offload_to_cpu = False
# HuggingFace Space persistent storage support
if persistent_storage_path is None and self.IS_HUGGINGFACE_SPACE:
persistent_storage_path = "/data"
self.persistent_storage_path = persistent_storage_path
# Shared constrained decoding processor
self.constrained_processor: Optional[MetadataConstrainedLogitsProcessor] = None
# Shared HuggingFace model for perplexity calculation
self._hf_model_for_scoring = None
def _get_checkpoint_dir(self) -> str:
"""Get checkpoint directory, prioritizing persistent storage"""
if self.persistent_storage_path:
return os.path.join(self.persistent_storage_path, "checkpoints")
current_file = os.path.abspath(__file__)
project_root = os.path.dirname(os.path.dirname(current_file))
return os.path.join(project_root, "checkpoints")
def get_available_5hz_lm_models(self) -> List[str]:
"""Scan and return all model directory names starting with 'acestep-5Hz-lm-'"""
checkpoint_dir = self._get_checkpoint_dir()
models = []
if os.path.exists(checkpoint_dir):
for item in os.listdir(checkpoint_dir):
item_path = os.path.join(checkpoint_dir, item)
if os.path.isdir(item_path) and item.startswith("acestep-5Hz-lm-"):
models.append(item)
models.sort()
return models
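    # Illustrative result (hypothetical directory layout): if the checkpoint dir
    # contains "acestep-5Hz-lm-0.6B" and "acestep-5Hz-lm-1.7B" plus unrelated
    # folders, this returns ["acestep-5Hz-lm-0.6B", "acestep-5Hz-lm-1.7B"];
    # anything without the "acestep-5Hz-lm-" prefix is ignored.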
def get_gpu_memory_utilization(self, minimal_gpu: float = 8, min_ratio: float = 0.2, max_ratio: float = 0.9) -> Tuple[float, bool]:
"""Get GPU memory utilization ratio"""
try:
device = torch.device("cuda:0")
total_gpu_mem_bytes = torch.cuda.get_device_properties(device).total_memory
allocated_mem_bytes = torch.cuda.memory_allocated(device)
reserved_mem_bytes = torch.cuda.memory_reserved(device)
total_gpu = total_gpu_mem_bytes / 1024**3
low_gpu_memory_mode = False
if total_gpu < minimal_gpu:
minimal_gpu = 0.5 * total_gpu
low_gpu_memory_mode = True
allocated_gpu = allocated_mem_bytes / 1024**3
reserved_gpu = reserved_mem_bytes / 1024**3
available_gpu = total_gpu - reserved_gpu
if available_gpu >= minimal_gpu:
ratio = min(max_ratio, max(min_ratio, minimal_gpu / total_gpu))
else:
ratio = min(max_ratio, max(min_ratio, (available_gpu * 0.8) / total_gpu))
return ratio, low_gpu_memory_mode
except Exception as e:
return 0.9, False
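    # Worked example (hypothetical numbers): on a 24 GB GPU with 4 GB reserved,
    # available_gpu = 20 >= minimal_gpu = 8, so ratio = clamp(8 / 24) ~= 0.33.
    # On a 6 GB GPU, minimal_gpu is halved to 3 and low_gpu_memory_mode is set;
    # with 1 GB reserved, available_gpu = 5 >= 3, so ratio = clamp(3 / 6) = 0.5.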
def _has_meaningful_negative_prompt(self, negative_prompt: str) -> bool:
"""Check if negative prompt is meaningful (not default/empty)"""
return negative_prompt and negative_prompt.strip() and negative_prompt.strip() != "NO USER INPUT"
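    # Illustrative behaviour: "", "   " and the sentinel "NO USER INPUT" all count
    # as "no negative prompt"; any other non-empty string (e.g. "noisy, distorted")
    # is treated as meaningful and used for CFG.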
def _build_logits_processor(self, repetition_penalty: float) -> LogitsProcessorList:
"""Build logits processor list with repetition penalty if needed"""
logits_processor = LogitsProcessorList()
if repetition_penalty != 1.0:
logits_processor.append(RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty))
return logits_processor
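    # Sketch of how the resulting list is used (assuming repetition_penalty=1.2):
    # it holds a single transformers RepetitionPenaltyLogitsProcessor and can be
    # applied to a generation step as
    #   scores = logits_processor(input_ids, scores)
    # With repetition_penalty == 1.0 the list stays empty and is passed as None.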
def _setup_constrained_processor(
self,
use_constrained_decoding: bool,
constrained_decoding_debug: bool,
target_duration: Optional[float],
user_metadata: Optional[Dict[str, Optional[str]]],
stop_at_reasoning: bool,
skip_genres: bool,
skip_caption: bool,
skip_language: bool,
generation_phase: str,
is_batch: bool = False,
metadata_temperature: Optional[float] = None,
codes_temperature: Optional[float] = None,
) -> Optional[MetadataConstrainedLogitsProcessor]:
"""Setup and configure constrained processor for generation"""
use_phase_temperatures = not is_batch and (metadata_temperature is not None or codes_temperature is not None)
if not use_constrained_decoding and not use_phase_temperatures:
return None
# Reset processor state for new generation
self.constrained_processor.reset()
# Use shared processor, just update settings
self.constrained_processor.enabled = use_constrained_decoding
self.constrained_processor.debug = constrained_decoding_debug
# Phase temperatures only supported in single mode
if use_phase_temperatures:
self.constrained_processor.metadata_temperature = metadata_temperature
self.constrained_processor.codes_temperature = codes_temperature
else:
self.constrained_processor.metadata_temperature = None
self.constrained_processor.codes_temperature = None
self.constrained_processor.set_target_duration(target_duration)
# Batch mode uses default/disabled settings for these options
if is_batch:
self.constrained_processor.set_user_metadata(None)
self.constrained_processor.set_stop_at_reasoning(False)
self.constrained_processor.set_skip_genres(True)
self.constrained_processor.set_skip_caption(True)
self.constrained_processor.set_skip_language(True)
else:
# Single mode uses provided settings
self.constrained_processor.set_user_metadata(user_metadata)
self.constrained_processor.set_stop_at_reasoning(stop_at_reasoning)
self.constrained_processor.set_skip_genres(skip_genres)
self.constrained_processor.set_skip_caption(skip_caption)
self.constrained_processor.set_skip_language(skip_language)
# Set generation phase for phase-aware processing
self.constrained_processor.set_generation_phase(generation_phase)
return self.constrained_processor
def _build_unconditional_prompt(
self,
caption: str,
lyrics: str,
cot_text: str,
negative_prompt: str,
generation_phase: str,
is_batch: bool = False,
) -> str:
"""Build unconditional prompt for CFG based on generation phase and batch mode"""
if is_batch or generation_phase == "codes":
# Codes phase or batch mode: use empty CoT in unconditional prompt
return self.build_formatted_prompt_with_cot(
caption, lyrics, cot_text, is_negative_prompt=True, negative_prompt=negative_prompt
)
else:
# CoT phase (single mode only): unconditional prompt
# If negative_prompt is provided, use it as caption; otherwise remove caption and keep only lyrics
return self.build_formatted_prompt(
caption, lyrics, is_negative_prompt=True, generation_phase="cot", negative_prompt=negative_prompt
)
def _load_pytorch_model(self, model_path: str, device: str) -> Tuple[bool, str]:
"""Load PyTorch model from path and return (success, status_message)"""
try:
self.llm = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True)
if not self.offload_to_cpu:
self.llm = self.llm.to(device).to(self.dtype)
else:
self.llm = self.llm.to("cpu").to(self.dtype)
self.llm.eval()
self.llm_backend = "pt"
self.llm_initialized = True
logger.info(f"5Hz LM initialized successfully using PyTorch backend on {device}")
status_msg = f"β
5Hz LM initialized successfully\nModel: {model_path}\nBackend: PyTorch\nDevice: {device}"
return True, status_msg
except Exception as e:
return False, f"β Error initializing 5Hz LM: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
def _apply_top_k_filter(self, logits: torch.Tensor, top_k: Optional[int]) -> torch.Tensor:
"""Apply top-k filtering to logits"""
if top_k is not None and top_k > 0:
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
logits[indices_to_remove] = float('-inf')
return logits
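    # Illustrative effect (hypothetical logits): with logits [4.0, 3.0, 1.0, 0.5]
    # and top_k=2, the last two entries are masked to -inf so only the two highest
    # scoring tokens remain samplable; top_k=None or 0 leaves the logits untouched.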
def _apply_top_p_filter(self, logits: torch.Tensor, top_p: Optional[float]) -> torch.Tensor:
"""Apply top-p (nucleus) filtering to logits"""
if top_p is not None and 0.0 < top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
sorted_indices_to_remove = cumulative_probs > top_p
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
logits[indices_to_remove] = float('-inf')
return logits
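    # Illustrative effect (hypothetical values): with token probabilities
    # [0.5, 0.3, 0.15, 0.05] and top_p=0.7, the cumulative sum first exceeds 0.7 at
    # the second token; the one-position shift keeps that boundary token, so the
    # first two tokens survive and the last two are masked to -inf. The shift also
    # guarantees the top-1 token is never removed.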
def _sample_tokens(self, logits: torch.Tensor, temperature: float) -> torch.Tensor:
"""Sample tokens from logits with temperature"""
if temperature > 0:
logits = logits / temperature
probs = torch.softmax(logits, dim=-1)
return torch.multinomial(probs, num_samples=1).squeeze(1)
else:
return torch.argmax(logits, dim=-1)
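    # Behaviour note: temperature > 0 rescales the logits and samples from the
    # softmax (higher temperature -> flatter distribution), while temperature == 0
    # falls back to greedy argmax decoding.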
def _check_eos_token(self, tokens: torch.Tensor, eos_token_id: int, pad_token_id: Optional[int]) -> bool:
"""Check if any token in the batch is EOS or pad token"""
if torch.any(tokens == eos_token_id):
return True
if pad_token_id is not None and pad_token_id != eos_token_id:
if torch.any(tokens == pad_token_id):
return True
return False
def _update_constrained_processor_state(self, constrained_processor: Optional[MetadataConstrainedLogitsProcessor], tokens: torch.Tensor):
"""Update constrained processor state with generated tokens"""
if constrained_processor is not None:
for b in range(tokens.shape[0]):
constrained_processor.update_state(tokens[b].item())
def _forward_pass(
self,
model: Any,
generated_ids: torch.Tensor,
model_kwargs: Dict[str, Any],
past_key_values: Optional[Any],
use_cache: bool,
) -> Any:
"""Perform forward pass with KV cache support"""
if past_key_values is None:
outputs = model(
input_ids=generated_ids,
**model_kwargs,
use_cache=use_cache,
)
else:
outputs = model(
input_ids=generated_ids[:, -1:],
past_key_values=past_key_values,
**model_kwargs,
use_cache=use_cache,
)
return outputs
def _normalize_batch_input(self, formatted_prompts: Union[str, List[str]]) -> Tuple[List[str], bool]:
"""Normalize batch input: convert single string to list and return (list, is_batch)"""
is_batch = isinstance(formatted_prompts, list)
if is_batch:
return formatted_prompts, is_batch
else:
return [formatted_prompts], is_batch
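    # Illustrative behaviour: "prompt" -> (["prompt"], False) and
    # ["p1", "p2"] -> (["p1", "p2"], True), letting downstream code handle both
    # cases uniformly and unwrap the single-item result at the end.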
def initialize(
self,
checkpoint_dir: str,
lm_model_path: str,
backend: str = "vllm",
device: str = "auto",
offload_to_cpu: bool = False,
dtype: Optional[torch.dtype] = None,
) -> Tuple[str, bool]:
"""
Initialize 5Hz LM model
Args:
checkpoint_dir: Checkpoint directory path
lm_model_path: LM model path (relative to checkpoint_dir)
backend: Backend type ("vllm" or "pt")
device: Device type ("auto", "cuda", or "cpu")
offload_to_cpu: Whether to offload to CPU
dtype: Data type (if None, auto-detect based on device)
Returns:
(status_message, success)
"""
try:
if device == "auto":
device = "cuda" if torch.cuda.is_available() else "cpu"
self.device = device
self.offload_to_cpu = offload_to_cpu
# Set dtype based on device: bfloat16 for cuda, float32 for cpu
if dtype is None:
self.dtype = torch.bfloat16 if device in ["cuda", "xpu"] else torch.float32
else:
self.dtype = dtype
# If lm_model_path is None, use default
if lm_model_path is None:
lm_model_path = "acestep-5Hz-lm-1.7B"
logger.info(f"[initialize] lm_model_path is None, using default: {lm_model_path}")
full_lm_model_path = os.path.join(checkpoint_dir, lm_model_path)
if not os.path.exists(full_lm_model_path):
return f"β 5Hz LM model not found at {full_lm_model_path}", False
logger.info("loading 5Hz LM tokenizer... it may take 80~90s")
start_time = time.time()
            # TODO: tokenizer loading is slow; no fix found yet
llm_tokenizer = AutoTokenizer.from_pretrained(full_lm_model_path, use_fast=True)
logger.info(f"5Hz LM tokenizer loaded successfully in {time.time() - start_time:.2f} seconds")
self.llm_tokenizer = llm_tokenizer
# Initialize shared constrained decoding processor (one-time initialization)
logger.info("Initializing constrained decoding processor...")
processor_start = time.time()
self.constrained_processor = MetadataConstrainedLogitsProcessor(
tokenizer=self.llm_tokenizer,
enabled=True,
debug=False,
)
logger.info(f"Constrained processor initialized in {time.time() - processor_start:.2f} seconds")
# Initialize based on user-selected backend
if backend == "vllm":
# Try to initialize with vllm
status_msg = self._initialize_5hz_lm_vllm(full_lm_model_path)
logger.info(f"5Hz LM status message: {status_msg}")
# Check if initialization failed (status_msg starts with β)
                if status_msg.startswith("❌"):
# vllm initialization failed, fallback to PyTorch
if not self.llm_initialized:
logger.warning("vllm initialization failed, falling back to PyTorch backend")
success, status_msg = self._load_pytorch_model(full_lm_model_path, device)
if not success:
return status_msg, False
status_msg = f"β
5Hz LM initialized successfully (PyTorch fallback)\nModel: {full_lm_model_path}\nBackend: PyTorch"
# If vllm initialization succeeded, self.llm_initialized should already be True
else:
# Use PyTorch backend (pt)
success, status_msg = self._load_pytorch_model(full_lm_model_path, device)
if not success:
return status_msg, False
return status_msg, True
except Exception as e:
return f"β Error initializing 5Hz LM: {str(e)}\n\nTraceback:\n{traceback.format_exc()}", False
def _initialize_5hz_lm_vllm(self, model_path: str) -> str:
"""Initialize 5Hz LM model using vllm backend"""
if not torch.cuda.is_available():
self.llm_initialized = False
logger.error("CUDA is not available. Please check your GPU setup.")
return "β CUDA is not available. Please check your GPU setup."
try:
from nanovllm import LLM, SamplingParams
except ImportError:
self.llm_initialized = False
logger.error("nano-vllm is not installed. Please install it using 'cd acestep/third_parts/nano-vllm && pip install .")
return "β nano-vllm is not installed. Please install it using 'cd acestep/third_parts/nano-vllm && pip install ."
try:
current_device = torch.cuda.current_device()
device_name = torch.cuda.get_device_name(current_device)
torch.cuda.empty_cache()
gpu_memory_utilization, low_gpu_memory_mode = self.get_gpu_memory_utilization(
minimal_gpu=8,
min_ratio=0.2,
max_ratio=0.9
)
if low_gpu_memory_mode:
self.max_model_len = 2048
else:
self.max_model_len = 4096
logger.info(f"Initializing 5Hz LM with model: {model_path}, enforce_eager: False, tensor_parallel_size: 1, max_model_len: {self.max_model_len}, gpu_memory_utilization: {gpu_memory_utilization}")
start_time = time.time()
self.llm = LLM(
model=model_path,
enforce_eager=False,
tensor_parallel_size=1,
max_model_len=self.max_model_len,
gpu_memory_utilization=gpu_memory_utilization,
tokenizer=self.llm_tokenizer,
)
logger.info(f"5Hz LM initialized successfully in {time.time() - start_time:.2f} seconds")
self.llm_initialized = True
self.llm_backend = "vllm"
return f"β
5Hz LM initialized successfully\nModel: {model_path}\nDevice: {device_name}\nGPU Memory Utilization: {gpu_memory_utilization:.2f}"
except Exception as e:
self.llm_initialized = False
return f"β Error initializing 5Hz LM: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
def _run_vllm(
self,
formatted_prompts: Union[str, List[str]],
temperature: float,
cfg_scale: float,
negative_prompt: str,
top_k: Optional[int],
top_p: Optional[float],
repetition_penalty: float,
use_constrained_decoding: bool = True,
constrained_decoding_debug: bool = False,
metadata_temperature: Optional[float] = None,
codes_temperature: Optional[float] = None,
target_duration: Optional[float] = None,
user_metadata: Optional[Dict[str, Optional[str]]] = None,
stop_at_reasoning: bool = False,
skip_genres: bool = True,
skip_caption: bool = False,
skip_language: bool = False,
generation_phase: str = "cot",
caption: str = "",
lyrics: str = "",
cot_text: str = "",
seeds: Optional[List[int]] = None,
) -> Union[str, List[str]]:
"""
Unified vllm generation function supporting both single and batch modes.
Accepts either a single formatted prompt (str) or a list of formatted prompts (List[str]).
Returns a single string for single mode, or a list of strings for batch mode.
"""
from nanovllm import SamplingParams
# Determine if batch mode
formatted_prompt_list, is_batch = self._normalize_batch_input(formatted_prompts)
batch_size = len(formatted_prompt_list)
# Determine effective temperature for sampler
# Batch mode doesn't support phase temperatures, so use simple temperature
# Single mode supports phase temperatures
use_phase_temperatures = not is_batch and (metadata_temperature is not None or codes_temperature is not None)
effective_sampler_temp = 1.0 if use_phase_temperatures else temperature
# Setup constrained processor
constrained_processor = self._setup_constrained_processor(
use_constrained_decoding=use_constrained_decoding or use_phase_temperatures,
constrained_decoding_debug=constrained_decoding_debug,
target_duration=target_duration,
user_metadata=user_metadata,
stop_at_reasoning=stop_at_reasoning,
skip_genres=skip_genres,
skip_caption=skip_caption,
skip_language=skip_language,
generation_phase=generation_phase,
is_batch=is_batch,
metadata_temperature=metadata_temperature,
codes_temperature=codes_temperature,
)
sampling_params = SamplingParams(
max_tokens=self.max_model_len - 64,
temperature=effective_sampler_temp,
cfg_scale=cfg_scale,
top_k=top_k,
top_p=top_p,
repetition_penalty=repetition_penalty,
logits_processor=constrained_processor,
logits_processor_update_state=constrained_processor.update_state if constrained_processor else None,
)
if cfg_scale > 1.0:
# Build unconditional prompt based on generation phase
formatted_unconditional_prompt = self._build_unconditional_prompt(
caption=caption,
lyrics=lyrics,
cot_text=cot_text,
negative_prompt=negative_prompt,
generation_phase=generation_phase,
is_batch=is_batch,
)
unconditional_prompts = [formatted_unconditional_prompt] * batch_size
outputs = self.llm.generate(
formatted_prompt_list,
sampling_params,
unconditional_prompts=unconditional_prompts,
)
else:
outputs = self.llm.generate(formatted_prompt_list, sampling_params)
# Extract text from outputs
output_texts = []
for output in outputs:
if hasattr(output, "outputs") and len(output.outputs) > 0:
output_texts.append(output.outputs[0].text)
elif hasattr(output, "text"):
output_texts.append(output.text)
elif isinstance(output, dict) and "text" in output:
output_texts.append(output["text"])
else:
output_texts.append(str(output))
# Return single string for single mode, list for batch mode
return output_texts[0] if not is_batch else output_texts
def _run_pt_single(
self,
formatted_prompt: str,
temperature: float,
cfg_scale: float,
negative_prompt: str,
top_k: Optional[int],
top_p: Optional[float],
repetition_penalty: float,
use_constrained_decoding: bool,
constrained_decoding_debug: bool,
target_duration: Optional[float],
user_metadata: Optional[Dict[str, Optional[str]]],
stop_at_reasoning: bool,
skip_genres: bool,
skip_caption: bool,
skip_language: bool,
generation_phase: str,
caption: str,
lyrics: str,
cot_text: str,
) -> str:
"""Internal helper function for single-item PyTorch generation."""
inputs = self.llm_tokenizer(
formatted_prompt,
return_tensors="pt",
padding=False,
truncation=True,
)
# Setup constrained processor
constrained_processor = self._setup_constrained_processor(
use_constrained_decoding=use_constrained_decoding,
constrained_decoding_debug=constrained_decoding_debug,
target_duration=target_duration,
user_metadata=user_metadata,
stop_at_reasoning=stop_at_reasoning,
skip_genres=skip_genres,
skip_caption=skip_caption,
skip_language=skip_language,
generation_phase=generation_phase,
is_batch=False,
)
with self._load_model_context():
inputs = {k: v.to(self.device) for k, v in inputs.items()}
max_new_tokens = getattr(self.llm.config, "max_new_tokens", 4096)
if hasattr(self, "max_model_len"):
max_new_tokens = min(max_new_tokens, self.max_model_len - 64)
# Build logits processor list (only for CFG and repetition penalty)
logits_processor = self._build_logits_processor(repetition_penalty)
if cfg_scale > 1.0:
# Build unconditional prompt based on generation phase
formatted_unconditional_prompt = self._build_unconditional_prompt(
caption=caption,
lyrics=lyrics,
cot_text=cot_text,
negative_prompt=negative_prompt,
generation_phase=generation_phase,
is_batch=False,
)
# Tokenize both prompts together to ensure same length (with left padding)
# Left padding is important for generation tasks
batch_texts = [formatted_prompt, formatted_unconditional_prompt]
original_padding_side = self.llm_tokenizer.padding_side
self.llm_tokenizer.padding_side = 'left'
batch_inputs_tokenized = self.llm_tokenizer(
batch_texts,
return_tensors="pt",
padding=True,
truncation=True,
)
self.llm_tokenizer.padding_side = original_padding_side
batch_inputs_tokenized = {k: v.to(self.device) for k, v in batch_inputs_tokenized.items()}
# Extract batch inputs
batch_input_ids = batch_inputs_tokenized['input_ids']
batch_attention_mask = batch_inputs_tokenized.get('attention_mask', None)
# Use custom CFG generation loop with constrained decoding
outputs = self._generate_with_cfg_custom(
batch_input_ids=batch_input_ids,
batch_attention_mask=batch_attention_mask,
max_new_tokens=max_new_tokens,
temperature=temperature,
cfg_scale=cfg_scale,
top_k=top_k,
top_p=top_p,
repetition_penalty=repetition_penalty,
pad_token_id=self.llm_tokenizer.pad_token_id or self.llm_tokenizer.eos_token_id,
streamer=None,
constrained_processor=constrained_processor,
)
# Extract only the conditional output (first in batch)
outputs = outputs[0:1] # Keep only conditional output
elif use_constrained_decoding:
# Use custom constrained decoding loop for non-CFG
outputs = self._generate_with_constrained_decoding(
input_ids=inputs["input_ids"],
attention_mask=inputs.get("attention_mask"),
max_new_tokens=max_new_tokens,
temperature=temperature,
top_k=top_k,
top_p=top_p,
repetition_penalty=repetition_penalty,
pad_token_id=self.llm_tokenizer.pad_token_id or self.llm_tokenizer.eos_token_id,
streamer=None,
constrained_processor=constrained_processor,
)
else:
# Generate without CFG using native generate() parameters
with torch.no_grad():
outputs = self.llm.generate(
**inputs,
max_new_tokens=max_new_tokens,
temperature=temperature if temperature > 0 else 1.0,
do_sample=True if temperature > 0 else False,
top_k=top_k if top_k is not None and top_k > 0 else None,
top_p=top_p if top_p is not None and 0.0 < top_p < 1.0 else None,
logits_processor=logits_processor if len(logits_processor) > 0 else None,
pad_token_id=self.llm_tokenizer.pad_token_id or self.llm_tokenizer.eos_token_id,
streamer=None,
)
# Decode the generated tokens
# outputs is a tensor with shape [batch_size, seq_len], extract first sequence
if isinstance(outputs, torch.Tensor):
if outputs.dim() == 2:
generated_ids = outputs[0]
else:
generated_ids = outputs
else:
generated_ids = outputs[0]
# Only decode the newly generated tokens (skip the input prompt)
# Use the original input length (before batch processing for CFG)
if cfg_scale > 1.0:
# In CFG case, we need to use the conditional input length from batch_inputs_tokenized
# Both sequences have the same length due to padding
input_length = batch_inputs_tokenized['input_ids'].shape[1]
else:
input_length = inputs["input_ids"].shape[1]
generated_ids = generated_ids[input_length:]
# Move to CPU for decoding
if generated_ids.is_cuda:
generated_ids = generated_ids.cpu()
output_text = self.llm_tokenizer.decode(generated_ids, skip_special_tokens=False)
return output_text
def _run_pt(
self,
formatted_prompts: Union[str, List[str]],
temperature: float,
cfg_scale: float,
negative_prompt: str,
top_k: Optional[int],
top_p: Optional[float],
repetition_penalty: float,
use_constrained_decoding: bool = True,
constrained_decoding_debug: bool = False,
target_duration: Optional[float] = None,
user_metadata: Optional[Dict[str, Optional[str]]] = None,
stop_at_reasoning: bool = False,
skip_genres: bool = True,
skip_caption: bool = False,
skip_language: bool = False,
generation_phase: str = "cot",
caption: str = "",
lyrics: str = "",
cot_text: str = "",
seeds: Optional[List[int]] = None,
) -> Union[str, List[str]]:
"""
Unified PyTorch generation function supporting both single and batch modes.
Accepts either a single formatted prompt (str) or a list of formatted prompts (List[str]).
Returns a single string for single mode, or a list of strings for batch mode.
Note: PyTorch backend processes batch items sequentially (doesn't support true batching efficiently).
"""
# Determine if batch mode
formatted_prompt_list, is_batch = self._normalize_batch_input(formatted_prompts)
# For batch mode, process each item sequentially with different seeds
if is_batch:
output_texts = []
for i, formatted_prompt in enumerate(formatted_prompt_list):
# Set seed for this item if provided
if seeds and i < len(seeds):
torch.manual_seed(seeds[i])
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seeds[i])
# Generate using single-item method with batch-mode defaults
output_text = self._run_pt_single(
formatted_prompt=formatted_prompt,
temperature=temperature,
cfg_scale=cfg_scale,
negative_prompt=negative_prompt,
top_k=top_k,
top_p=top_p,
repetition_penalty=repetition_penalty,
use_constrained_decoding=use_constrained_decoding,
constrained_decoding_debug=constrained_decoding_debug,
target_duration=target_duration,
user_metadata=None,
stop_at_reasoning=False,
skip_genres=True,
skip_caption=True,
skip_language=True,
generation_phase=generation_phase,
caption=caption,
lyrics=lyrics,
cot_text=cot_text,
)
output_texts.append(output_text)
return output_texts
# Single mode: process the formatted prompt
formatted_prompt = formatted_prompt_list[0]
return self._run_pt_single(
formatted_prompt=formatted_prompt,
temperature=temperature,
cfg_scale=cfg_scale,
negative_prompt=negative_prompt,
top_k=top_k,
top_p=top_p,
repetition_penalty=repetition_penalty,
use_constrained_decoding=use_constrained_decoding,
constrained_decoding_debug=constrained_decoding_debug,
target_duration=target_duration,
user_metadata=user_metadata,
stop_at_reasoning=stop_at_reasoning,
skip_genres=skip_genres,
skip_caption=skip_caption,
skip_language=skip_language,
generation_phase=generation_phase,
caption=caption,
lyrics=lyrics,
cot_text=cot_text,
)
def has_all_metas(self, user_metadata: Optional[Dict[str, Optional[str]]]) -> bool:
"""Check if all required metadata are present."""
if user_metadata is None:
return False
if 'bpm' in user_metadata and 'keyscale' in user_metadata and 'timesignature' in user_metadata and 'duration' in user_metadata:
return True
return False
def _format_metadata_as_cot(self, metadata: Dict[str, Any]) -> str:
"""
Format parsed metadata as CoT text using YAML format (matching training format).
Args:
metadata: Dictionary with keys: bpm, caption, duration, keyscale, language, timesignature
Returns:
Formatted CoT text: "<think>\n{yaml_content}\n</think>"
"""
# Build cot_items dict with only non-None values
cot_items = {}
for key in ['bpm', 'caption', 'duration', 'keyscale', 'language', 'timesignature']:
if key in metadata and metadata[key] is not None:
value = metadata[key]
if key == "timesignature" and value.endswith("/4"):
value = value.split("/")[0]
if isinstance(value, str) and value.isdigit():
value = int(value)
cot_items[key] = value
# Format as YAML (sorted keys, unicode support)
if len(cot_items) > 0:
cot_yaml = yaml.dump(cot_items, allow_unicode=True, sort_keys=True).strip()
else:
cot_yaml = ""
return f"<think>\n{cot_yaml}\n</think>"
def generate_with_stop_condition(
self,
caption: str,
lyrics: str,
infer_type: str,
temperature: float = 0.85,
cfg_scale: float = 1.0,
negative_prompt: str = "NO USER INPUT",
top_k: Optional[int] = None,
top_p: Optional[float] = None,
repetition_penalty: float = 1.0,
use_constrained_decoding: bool = True,
constrained_decoding_debug: bool = False,
target_duration: Optional[float] = None,
user_metadata: Optional[Dict[str, Optional[str]]] = None,
use_cot_metas: bool = True,
use_cot_caption: bool = True,
use_cot_language: bool = True,
batch_size: Optional[int] = None,
seeds: Optional[List[int]] = None,
progress=None,
) -> Dict[str, Any]:
"""Two-phase LM generation: CoT generation followed by audio codes generation.
- infer_type='dit': Phase 1 only - generate CoT and return metas (no audio codes)
- infer_type='llm_dit': Phase 1 + Phase 2 - generate CoT then audio codes
Args:
target_duration: Target duration in seconds for codes generation constraint.
5 codes = 1 second. If specified, blocks EOS until target reached.
user_metadata: User-provided metadata fields (e.g. bpm/duration/keyscale/timesignature).
If specified, constrained decoding will inject these values directly.
use_cot_caption: Whether to generate caption in CoT (default True).
use_cot_language: Whether to generate language in CoT (default True).
batch_size: Optional batch size for batch generation. If None or 1, returns single result.
If > 1, returns batch results (lists).
seeds: Optional list of seeds for batch generation (for reproducibility).
Only used when batch_size > 1. TODO: not used yet
Returns:
Dictionary containing:
- metadata: Dict or List[Dict] - Generated metadata
- audio_codes: str or List[str] - Generated audio codes
- success: bool - Whether generation succeeded
- error: Optional[str] - Error message if failed
- extra_outputs: Dict with time_costs and other info
"""
if progress is None:
def progress(*args, **kwargs):
pass
infer_type = (infer_type or "").strip().lower()
if infer_type not in {"dit", "llm_dit"}:
error_msg = f"invalid infer_type: {infer_type!r} (expected 'dit' or 'llm_dit')"
return {
"metadata": [] if (batch_size and batch_size > 1) else {},
"audio_codes": [] if (batch_size and batch_size > 1) else "",
"success": False,
"error": error_msg,
"extra_outputs": {"time_costs": {}},
}
# Determine if batch mode
is_batch = batch_size and batch_size > 1
actual_batch_size = batch_size if is_batch else 1
# Initialize variables
metadata = {}
audio_codes = ""
has_all_metas = self.has_all_metas(user_metadata)
phase1_time = 0.0
phase2_time = 0.0
# Handle seeds for batch mode
if is_batch:
if seeds is None:
seeds = [random.randint(0, 2**32 - 1) for _ in range(actual_batch_size)]
elif len(seeds) < actual_batch_size:
seeds = list(seeds) + [random.randint(0, 2**32 - 1) for _ in range(actual_batch_size - len(seeds))]
else:
seeds = seeds[:actual_batch_size]
# ========== PHASE 1: CoT Generation ==========
# Skip CoT if all metadata are user-provided OR caption is already formatted
progress(0.1, f"Phase 1: Generating CoT metadata (once for all items)...")
if not has_all_metas and use_cot_metas:
if is_batch:
logger.info("Batch Phase 1: Generating CoT metadata (once for all items)...")
else:
logger.info("Phase 1: Generating CoT metadata...")
phase1_start = time.time()
# Build formatted prompt for CoT phase
formatted_prompt = self.build_formatted_prompt(caption, lyrics, generation_phase="cot")
logger.info(f"generate_with_stop_condition: formatted_prompt={formatted_prompt}")
# Generate CoT (stop at </think>)
cot_output_text, status = self.generate_from_formatted_prompt(
formatted_prompt=formatted_prompt,
cfg={
"temperature": temperature,
"cfg_scale": cfg_scale,
"negative_prompt": negative_prompt,
"top_k": top_k,
"top_p": top_p,
"repetition_penalty": repetition_penalty,
"target_duration": None, # No duration constraint for CoT phase
"user_metadata": user_metadata,
"skip_caption": not use_cot_caption,
"skip_language": not use_cot_language,
"skip_genres": True, # Generate genres
"generation_phase": "cot",
# Pass context for building unconditional prompt in CoT phase
"caption": caption,
"lyrics": lyrics,
},
use_constrained_decoding=use_constrained_decoding,
constrained_decoding_debug=constrained_decoding_debug,
stop_at_reasoning=True, # Always stop at </think> in Phase 1
)
phase1_time = time.time() - phase1_start
if not cot_output_text:
return {
"metadata": [] if is_batch else {},
"audio_codes": [] if is_batch else "",
"success": False,
"error": status,
"extra_outputs": {"time_costs": {"phase1_time": phase1_time}},
}
# Parse metadata from CoT output
metadata, _ = self.parse_lm_output(cot_output_text)
if is_batch:
logger.info(f"Batch Phase 1 completed in {phase1_time:.2f}s. Generated metadata: {list(metadata.keys())}")
else:
logger.info(f"Phase 1 completed in {phase1_time:.2f}s. Generated metadata: {list(metadata.keys())}")
else:
# Use user-provided metadata
if is_batch:
logger.info("Batch Phase 1: Using user-provided metadata (skipping generation)")
else:
logger.info("Phase 1: Using user-provided metadata (skipping generation)")
metadata = {k: v for k, v in user_metadata.items() if v is not None}
# If infer_type is 'dit', stop here and return only metadata
if infer_type == "dit":
if is_batch:
metadata_list = [metadata.copy() for _ in range(actual_batch_size)]
return {
"metadata": metadata_list,
"audio_codes": [""] * actual_batch_size,
"success": True,
"error": None,
"extra_outputs": {
"time_costs": {
"phase1_time": phase1_time,
"total_time": phase1_time,
}
},
}
else:
return {
"metadata": metadata,
"audio_codes": "",
"success": True,
"error": None,
"extra_outputs": {
"time_costs": {
"phase1_time": phase1_time,
"total_time": phase1_time,
}
},
}
# ========== PHASE 2: Audio Codes Generation ==========
if is_batch:
logger.info(f"Batch Phase 2: Generating audio codes for {actual_batch_size} items...")
else:
logger.info("Phase 2: Generating audio codes...")
phase2_start = time.time()
# Format metadata as CoT using YAML (matching training format)
cot_text = self._format_metadata_as_cot(metadata)
# Build formatted prompt with CoT for codes generation phase
formatted_prompt_with_cot = self.build_formatted_prompt_with_cot(caption, lyrics, cot_text)
logger.info(f"generate_with_stop_condition: formatted_prompt_with_cot={formatted_prompt_with_cot}")
progress(0.5, f"Phase 2: Generating audio codes for {actual_batch_size} items...")
if is_batch:
# Batch mode: generate codes for all items
formatted_prompts = [formatted_prompt_with_cot] * actual_batch_size
# Call backend-specific batch generation
try:
if self.llm_backend == "vllm":
codes_outputs = self._run_vllm(
formatted_prompts=formatted_prompts,
temperature=temperature,
cfg_scale=cfg_scale,
negative_prompt=negative_prompt,
top_k=top_k,
top_p=top_p,
repetition_penalty=repetition_penalty,
use_constrained_decoding=use_constrained_decoding,
constrained_decoding_debug=constrained_decoding_debug,
target_duration=target_duration,
generation_phase="codes",
caption=caption,
lyrics=lyrics,
cot_text=cot_text,
seeds=seeds,
)
else: # pt backend
codes_outputs = self._run_pt(
formatted_prompts=formatted_prompts,
temperature=temperature,
cfg_scale=cfg_scale,
negative_prompt=negative_prompt,
top_k=top_k,
top_p=top_p,
repetition_penalty=repetition_penalty,
use_constrained_decoding=use_constrained_decoding,
constrained_decoding_debug=constrained_decoding_debug,
target_duration=target_duration,
generation_phase="codes",
caption=caption,
lyrics=lyrics,
cot_text=cot_text,
seeds=seeds,
)
except Exception as e:
error_msg = f"Error in batch codes generation: {str(e)}"
logger.error(error_msg)
return {
"metadata": [],
"audio_codes": [],
"success": False,
"error": error_msg,
"extra_outputs": {
"time_costs": {
"phase1_time": phase1_time,
"phase2_time": 0.0,
"total_time": phase1_time,
}
},
}
# Parse audio codes from each output
audio_codes_list = []
metadata_list = []
for output_text in codes_outputs:
_, audio_codes_item = self.parse_lm_output(output_text)
audio_codes_list.append(audio_codes_item)
metadata_list.append(metadata.copy()) # Same metadata for all
phase2_time = time.time() - phase2_start
# Log results
codes_counts = [len(codes.split('<|audio_code_')) - 1 if codes else 0 for codes in audio_codes_list]
logger.info(f"Batch Phase 2 completed in {phase2_time:.2f}s. Generated codes: {codes_counts}")
total_time = phase1_time + phase2_time
return {
"metadata": metadata_list,
"audio_codes": audio_codes_list,
"success": True,
"error": None,
"extra_outputs": {
"time_costs": {
"phase1_time": phase1_time,
"phase2_time": phase2_time,
"total_time": total_time,
},
"codes_counts": codes_counts,
"total_codes": sum(codes_counts),
},
}
else:
# Single mode: generate codes for one item
codes_output_text, status = self.generate_from_formatted_prompt(
formatted_prompt=formatted_prompt_with_cot,
cfg={
"temperature": temperature,
"cfg_scale": cfg_scale,
"negative_prompt": negative_prompt,
"top_k": top_k,
"top_p": top_p,
"repetition_penalty": repetition_penalty,
"target_duration": target_duration,
"user_metadata": None, # No user metadata injection in Phase 2
"skip_caption": True, # Skip caption since CoT is already included
"skip_language": True, # Skip language since CoT is already included
"generation_phase": "codes",
# Pass context for building unconditional prompt in codes phase
"caption": caption,
"lyrics": lyrics,
"cot_text": cot_text,
},
use_constrained_decoding=use_constrained_decoding,
constrained_decoding_debug=constrained_decoding_debug,
stop_at_reasoning=False, # Generate codes until EOS
)
if not codes_output_text:
total_time = phase1_time + phase2_time
return {
"metadata": metadata,
"audio_codes": "",
"success": False,
"error": status,
"extra_outputs": {
"time_costs": {
"phase1_time": phase1_time,
"phase2_time": phase2_time,
"total_time": total_time,
}
},
}
phase2_time = time.time() - phase2_start
# Parse audio codes from output (metadata should be same as Phase 1)
_, audio_codes = self.parse_lm_output(codes_output_text)
codes_count = len(audio_codes.split('<|audio_code_')) - 1 if audio_codes else 0
logger.info(f"Phase 2 completed in {phase2_time:.2f}s. Generated {codes_count} audio codes")
total_time = phase1_time + phase2_time
return {
"metadata": metadata,
"audio_codes": audio_codes,
"success": True,
"error": None,
"extra_outputs": {
"time_costs": {
"phase1_time": phase1_time,
"phase2_time": phase2_time,
"total_time": total_time,
},
"codes_count": codes_count,
},
}
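    # Illustrative usage (hypothetical inputs): full two-phase generation; the
    # returned dict carries the parsed CoT metadata and the audio code string
    # (codes are empty when infer_type="dit").
    #   result = handler.generate_with_stop_condition(
    #       caption="calm piano",
    #       lyrics="hello world",
    #       infer_type="llm_dit",
    #       temperature=0.85,
    #       target_duration=30.0,
    #   )
    #   if result["success"]:
    #       codes = result["audio_codes"]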
def build_formatted_prompt(self, caption: str, lyrics: str = "", is_negative_prompt: bool = False, generation_phase: str = "cot", negative_prompt: str = "NO USER INPUT") -> str:
"""
Build the chat-formatted prompt for 5Hz LM from caption/lyrics.
Raises a ValueError if the tokenizer is not initialized.
Args:
caption: Caption text
lyrics: Lyrics text
is_negative_prompt: If True, builds unconditional prompt for CFG
generation_phase: "cot" or "codes" - affects unconditional prompt format
negative_prompt: Negative prompt for CFG (used when is_negative_prompt=True)
Example:
prompt = handler.build_formatted_prompt("calm piano", "hello world")
"""
if self.llm_tokenizer is None:
raise ValueError("LLM tokenizer is not initialized. Call initialize() first.")
if is_negative_prompt:
# Unconditional prompt for CFG
# Check if user provided a meaningful negative prompt (not the default)
has_negative_prompt = self._has_meaningful_negative_prompt(negative_prompt)
if generation_phase == "cot":
# CoT phase unconditional prompt
if has_negative_prompt:
# If negative prompt provided, use it as caption
prompt = f"# Caption\n{negative_prompt}\n\n# Lyric\n{lyrics}\n"
else:
# No negative prompt: remove caption, keep only lyrics
prompt = f"# Lyric\n{lyrics}\n"
else:
# Codes phase: will be handled by build_formatted_prompt_with_cot
# For backward compatibility, use simple caption as before
prompt = caption
else:
# Conditional prompt: include both caption and lyrics
prompt = f"# Caption\n{caption}\n\n# Lyric\n{lyrics}\n"
return self.llm_tokenizer.apply_chat_template(
[
{"role": "system", "content": f"# Instruction\n{DEFAULT_LM_INSTRUCTION}\n\n"},
{"role": "user", "content": prompt},
],
tokenize=False,
add_generation_prompt=True,
)
def build_formatted_prompt_with_cot(self, caption: str, lyrics: str, cot_text: str, is_negative_prompt: bool = False, negative_prompt: str = "NO USER INPUT") -> str:
"""
Build the chat-formatted prompt for codes generation phase with pre-generated CoT.
Args:
caption: Caption text
lyrics: Lyrics text
cot_text: Pre-generated CoT text (e.g., "<think>\\nbpm: 120\\n...\\n</think>")
is_negative_prompt: If True, uses empty CoT for CFG unconditional prompt
negative_prompt: Negative prompt for CFG (used when is_negative_prompt=True)
Returns:
Formatted prompt string
Example:
cot = "<think>\\nbpm: 120\\ncaption: calm piano\\n...\\n</think>"
prompt = handler.build_formatted_prompt_with_cot("calm piano", "hello", cot)
"""
if self.llm_tokenizer is None:
raise ValueError("LLM tokenizer is not initialized. Call initialize() first.")
if is_negative_prompt:
# Unconditional prompt for codes phase
# Check if user provided a meaningful negative prompt
has_negative_prompt = self._has_meaningful_negative_prompt(negative_prompt)
# Use empty CoT for unconditional
cot_for_prompt = "<think>\n</think>"
if has_negative_prompt:
# If negative prompt provided, use it as caption
caption_for_prompt = negative_prompt
else:
# No negative prompt: use original caption
caption_for_prompt = caption
else:
# Conditional prompt: use the full CoT and original caption
cot_for_prompt = cot_text
caption_for_prompt = caption
# Build user prompt with caption and lyrics ONLY (no COT)
# COT should be in the assistant's message, not user's
user_prompt = f"# Caption\n{caption_for_prompt}\n\n# Lyric\n{lyrics}\n"
# Build the chat with assistant message containing the COT
# The model will continue generation after the COT
formatted = self.llm_tokenizer.apply_chat_template(
[
{"role": "system", "content": f"# Instruction\n{DEFAULT_LM_INSTRUCTION}\n\n"},
{"role": "user", "content": user_prompt},
{"role": "assistant", "content": cot_for_prompt},
],
tokenize=False,
add_generation_prompt=False, # Don't add generation prompt, COT is already in assistant
)
# Add a newline after </think> so model generates audio codes on next line
if not formatted.endswith('\n'):
formatted += '\n'
return formatted
def build_formatted_prompt_for_understanding(
self,
audio_codes: str,
is_negative_prompt: bool = False,
negative_prompt: str = "NO USER INPUT"
) -> str:
"""
Build the chat-formatted prompt for audio understanding from codes.
This is the reverse of generation: given audio codes, generate metadata and lyrics.
Args:
audio_codes: Audio code string (e.g., "<|audio_code_123|><|audio_code_456|>...")
is_negative_prompt: If True, builds unconditional prompt for CFG
negative_prompt: Negative prompt for CFG (used when is_negative_prompt=True)
Returns:
Formatted prompt string
Example:
codes = "<|audio_code_18953|><|audio_code_13833|>..."
prompt = handler.build_formatted_prompt_for_understanding(codes)
"""
if self.llm_tokenizer is None:
raise ValueError("LLM tokenizer is not initialized. Call initialize() first.")
# For understanding task, user provides audio codes
# Unconditional prompt uses negative_prompt or empty string
if is_negative_prompt:
user_content = negative_prompt if negative_prompt and negative_prompt.strip() else ""
else:
user_content = audio_codes
return self.llm_tokenizer.apply_chat_template(
[
{
"role": "system",
"content": f"# Instruction\n{DEFAULT_LM_UNDERSTAND_INSTRUCTION}\n\n"
},
{
"role": "user",
"content": user_content
},
],
tokenize=False,
add_generation_prompt=True,
)
def understand_audio_from_codes(
self,
audio_codes: str,
temperature: float = 0.3,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
repetition_penalty: float = 1.0,
use_constrained_decoding: bool = True,
constrained_decoding_debug: bool = False,
) -> Tuple[Dict[str, Any], str]:
"""
Understand audio codes and generate metadata + lyrics.
This is the reverse of the normal generation flow:
- Input: Audio codes
- Output: Metadata (bpm, caption, duration, etc.) + Lyrics
Note: cfg_scale and negative_prompt are not supported in understand mode.
Args:
audio_codes: String of audio code tokens (e.g., "<|audio_code_123|><|audio_code_456|>...")
temperature: Sampling temperature for generation
top_k: Top-K sampling (None = disabled)
top_p: Top-P (nucleus) sampling (None = disabled)
repetition_penalty: Repetition penalty (1.0 = no penalty)
use_constrained_decoding: Whether to use FSM-based constrained decoding for metadata
constrained_decoding_debug: Whether to enable debug logging for constrained decoding
Returns:
Tuple of (metadata_dict, status_message)
metadata_dict contains:
- bpm: int or str
- caption: str
- duration: int or str
- keyscale: str
- language: str
- timesignature: str
- lyrics: str (extracted from output after </think>)
Example:
codes = "<|audio_code_18953|><|audio_code_13833|>..."
metadata, status = handler.understand_audio_from_codes(codes)
print(metadata['caption']) # "A cinematic orchestral piece..."
print(metadata['lyrics']) # "[Intro: ...]\\n..."
"""
if not getattr(self, "llm_initialized", False):
return {}, "β 5Hz LM not initialized. Please initialize it first."
if not audio_codes or not audio_codes.strip():
return {}, "β No audio codes provided. Please paste audio codes first."
logger.info(f"Understanding audio codes (length: {len(audio_codes)} chars)")
# Build formatted prompt for understanding
formatted_prompt = self.build_formatted_prompt_for_understanding(audio_codes)
print(f"formatted_prompt: {formatted_prompt}")
# Generate using constrained decoding (understand phase)
# We want to generate metadata first (CoT), then lyrics (natural text)
# Note: cfg_scale and negative_prompt are not used in understand mode
output_text, status = self.generate_from_formatted_prompt(
formatted_prompt=formatted_prompt,
cfg={
"temperature": temperature,
"top_k": top_k,
"top_p": top_p,
"repetition_penalty": repetition_penalty,
"target_duration": None, # No duration constraint for understanding
"user_metadata": None, # No user metadata injection
"skip_caption": False, # Generate caption
"skip_language": False, # Generate language
"skip_genres": False, # Generate genres
"generation_phase": "understand", # Understanding phase: generate CoT metadata, then free-form lyrics
# Context for building unconditional prompt
"caption": "",
"lyrics": "",
},
use_constrained_decoding=use_constrained_decoding,
constrained_decoding_debug=constrained_decoding_debug,
stop_at_reasoning=False, # Continue after </think> to generate lyrics
)
if not output_text:
return {}, status
# Parse metadata and extract lyrics
metadata, _ = self.parse_lm_output(output_text)
# Extract lyrics section (everything after </think>)
lyrics = self._extract_lyrics_from_output(output_text)
if lyrics:
metadata['lyrics'] = lyrics
logger.info(f"Understanding completed. Generated {len(metadata)} metadata fields")
if constrained_decoding_debug:
logger.debug(f"Generated metadata: {list(metadata.keys())}")
logger.debug(f"Output text preview: {output_text[:200]}...")
status_msg = f"β
Understanding completed successfully\nGenerated fields: {', '.join(metadata.keys())}"
return metadata, status_msg
def _extract_lyrics_from_output(self, output_text: str) -> str:
"""
Extract lyrics section from LLM output.
The lyrics appear after the </think> tag and typically start with "# Lyric"
or directly with lyric content.
Args:
output_text: Full LLM output text
Returns:
Extracted lyrics string, or empty string if no lyrics found
"""
import re
# Find the </think> tag
think_end_pattern = r'</think>'
match = re.search(think_end_pattern, output_text)
if not match:
# No </think> tag found, no lyrics
return ""
# Extract everything after </think>
after_think = output_text[match.end():].strip()
if not after_think:
return ""
# Remove "# Lyric" header if present
lyric_header_pattern = r'^#\s*Lyrics?\s*\n'
after_think = re.sub(lyric_header_pattern, '', after_think, flags=re.IGNORECASE)
# Remove <|im_end|> tag at the end if present
after_think = re.sub(r'<\|im_end\|>\s*$', '', after_think)
return after_think.strip()
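# Illustrative note: given output such as
#   "<think>\nbpm: 90\ncaption: calm piano\n</think>\n# Lyric\n[Verse 1]\nHello<|im_end|>"
# _extract_lyrics_from_output returns "[Verse 1]\nHello": the reasoning block is
# skipped, the "# Lyric" header is stripped, and the trailing <|im_end|> marker is removed.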
def build_formatted_prompt_for_inspiration(
self,
query: str,
instrumental: bool = False,
is_negative_prompt: bool = False,
negative_prompt: str = "NO USER INPUT"
) -> str:
"""
Build the chat-formatted prompt for inspiration/simple mode.
This generates a complete sample (caption, lyrics, metadata) from a user's
natural language music description query.
Args:
query: User's natural language music description
instrumental: Whether to generate instrumental music (no vocals)
is_negative_prompt: If True, builds unconditional prompt for CFG
negative_prompt: Negative prompt for CFG (used when is_negative_prompt=True)
Returns:
Formatted prompt string
Example:
query = "a soft Bengali love song for a quiet evening"
prompt = handler.build_formatted_prompt_for_inspiration(query, instrumental=False)
"""
if self.llm_tokenizer is None:
raise ValueError("LLM tokenizer is not initialized. Call initialize() first.")
# Build user content with query and instrumental flag
instrumental_str = "true" if instrumental else "false"
if is_negative_prompt:
# For CFG unconditional prompt
user_content = negative_prompt if negative_prompt and negative_prompt.strip() else ""
else:
# Normal prompt: query + instrumental flag
user_content = f"{query}\n\ninstrumental: {instrumental_str}"
return self.llm_tokenizer.apply_chat_template(
[
{
"role": "system",
"content": f"# Instruction\n{DEFAULT_LM_INSPIRED_INSTRUCTION}\n\n"
},
{
"role": "user",
"content": user_content
},
],
tokenize=False,
add_generation_prompt=True,
)
def create_sample_from_query(
self,
query: str,
instrumental: bool = False,
vocal_language: Optional[str] = None,
temperature: float = 0.85,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
repetition_penalty: float = 1.0,
use_constrained_decoding: bool = True,
constrained_decoding_debug: bool = False,
) -> Tuple[Dict[str, Any], str]:
"""
Create a complete music sample from a user's natural language query.
This is the "Simple Mode" / "Inspiration Mode" feature that generates:
- Metadata (bpm, caption, duration, keyscale, language, timesignature)
- Lyrics (unless instrumental=True)
Args:
query: User's natural language music description
instrumental: Whether to generate instrumental music (no vocals)
vocal_language: Allowed vocal language for constrained decoding (e.g., "en", "zh").
If provided and not "unknown", it will be used.
temperature: Sampling temperature for generation (0.0-2.0)
top_k: Top-K sampling (None = disabled)
top_p: Top-P (nucleus) sampling (None = disabled)
repetition_penalty: Repetition penalty (1.0 = no penalty)
use_constrained_decoding: Whether to use FSM-based constrained decoding
constrained_decoding_debug: Whether to enable debug logging
Returns:
Tuple of (metadata_dict, status_message)
metadata_dict contains:
- bpm: int or str
- caption: str
- duration: int or str
- keyscale: str
- language: str
- timesignature: str
- lyrics: str (extracted from output after </think>)
- instrumental: bool (echoed back)
Example:
query = "a soft Bengali love song for a quiet evening"
metadata, status = handler.create_sample_from_query(query, instrumental=False, vocal_language="bn")
print(metadata['caption']) # "A gentle romantic acoustic pop ballad..."
print(metadata['lyrics']) # "[Intro: ...]\\n..."
"""
if not getattr(self, "llm_initialized", False):
return {}, "β 5Hz LM not initialized. Please initialize it first."
if not query or not query.strip():
query = "NO USER INPUT"
logger.info(f"Creating sample from query: {query[:100]}... (instrumental={instrumental}, vocal_language={vocal_language})")
# Build formatted prompt for inspiration
formatted_prompt = self.build_formatted_prompt_for_inspiration(
query=query,
instrumental=instrumental,
)
logger.debug(f"Formatted prompt for inspiration: {formatted_prompt}")
# Build user_metadata if vocal_language is specified and is not "unknown"
user_metadata = None
skip_language = False
if vocal_language and vocal_language.strip() and vocal_language.strip().lower() != "unknown":
# Use the specified language for constrained decoding
user_metadata = {"language": vocal_language.strip()}
# skip_language = True # Skip language generation since we're injecting it
logger.info(f"Using user-specified language: {vocal_language.strip()}")
# Generate using constrained decoding (inspiration phase)
# Similar to understand mode - generate metadata first (CoT), then lyrics
# Note: cfg_scale and negative_prompt are not used in create_sample mode
output_text, status = self.generate_from_formatted_prompt(
formatted_prompt=formatted_prompt,
cfg={
"temperature": temperature,
"top_k": top_k,
"top_p": top_p,
"repetition_penalty": repetition_penalty,
"target_duration": None, # No duration constraint
"user_metadata": user_metadata, # Inject language if specified
"skip_caption": False, # Generate caption
"skip_language": False,
"skip_genres": False, # Generate genres
"generation_phase": "understand", # Use understand phase for metadata + free-form lyrics
"caption": "",
"lyrics": "",
},
use_constrained_decoding=use_constrained_decoding,
constrained_decoding_debug=constrained_decoding_debug,
stop_at_reasoning=False, # Continue after </think> to generate lyrics
)
if not output_text:
return {}, status
# Parse metadata and extract lyrics
metadata, _ = self.parse_lm_output(output_text)
# Extract lyrics section (everything after </think>)
lyrics = self._extract_lyrics_from_output(output_text)
if lyrics:
metadata['lyrics'] = lyrics
elif instrumental:
# For instrumental, set empty lyrics or placeholder
metadata['lyrics'] = "[Instrumental]"
# Echo back the instrumental flag
metadata['instrumental'] = instrumental
logger.info(f"Sample created successfully. Generated {metadata} fields")
if constrained_decoding_debug:
logger.debug(f"Generated metadata: {list(metadata.keys())}")
logger.debug(f"Output text preview: {output_text[:300]}...")
status_msg = f"β
Sample created successfully\nGenerated fields: {metadata}"
return metadata, status_msg
def build_formatted_prompt_for_format(
self,
caption: str,
lyrics: str,
is_negative_prompt: bool = False,
negative_prompt: str = "NO USER INPUT"
) -> str:
"""
Build the chat-formatted prompt for format/rewrite mode.
This formats user-provided caption and lyrics into a more detailed and specific
musical description with metadata.
Args:
caption: User's caption/description of the music
lyrics: User's lyrics
is_negative_prompt: If True, builds unconditional prompt for CFG
negative_prompt: Negative prompt for CFG (used when is_negative_prompt=True)
Returns:
Formatted prompt string
Example:
caption = "Latin pop, reggaeton, flamenco-pop"
lyrics = "[Verse 1]\\nTengo un nudo..."
prompt = handler.build_formatted_prompt_for_format(caption, lyrics)
"""
if self.llm_tokenizer is None:
raise ValueError("LLM tokenizer is not initialized. Call initialize() first.")
if is_negative_prompt:
# For CFG unconditional prompt
user_content = negative_prompt if negative_prompt and negative_prompt.strip() else ""
else:
# Normal prompt: caption + lyrics
user_content = f"# Caption\n{caption}\n\n# Lyric\n{lyrics}"
return self.llm_tokenizer.apply_chat_template(
[
{
"role": "system",
"content": f"# Instruction\n{DEFAULT_LM_REWRITE_INSTRUCTION}\n\n"
},
{
"role": "user",
"content": user_content
},
],
tokenize=False,
add_generation_prompt=True,
)
def format_sample_from_input(
self,
caption: str,
lyrics: str,
user_metadata: Optional[Dict[str, Any]] = None,
temperature: float = 0.85,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
repetition_penalty: float = 1.0,
use_constrained_decoding: bool = True,
constrained_decoding_debug: bool = False,
) -> Tuple[Dict[str, Any], str]:
"""
Format user-provided caption and lyrics into structured music metadata.
This is the "Format" feature that takes user input and generates:
- Enhanced caption with detailed music description
- Metadata (bpm, duration, keyscale, language, timesignature)
- Formatted lyrics (preserved from input)
Note: cfg_scale and negative_prompt are not supported in format mode.
Args:
caption: User's caption/description (e.g., "Latin pop, reggaeton")
lyrics: User's lyrics with structure tags
user_metadata: Optional dict with user-provided metadata to constrain decoding.
Supported keys: bpm, duration, keyscale, timesignature, language
temperature: Sampling temperature for generation (0.0-2.0)
top_k: Top-K sampling (None = disabled)
top_p: Top-P (nucleus) sampling (None = disabled)
repetition_penalty: Repetition penalty (1.0 = no penalty)
use_constrained_decoding: Whether to use FSM-based constrained decoding
constrained_decoding_debug: Whether to enable debug logging
Returns:
Tuple of (metadata_dict, status_message)
metadata_dict contains:
- bpm: int or str
- caption: str (enhanced)
- duration: int or str
- keyscale: str
- language: str
- timesignature: str
- lyrics: str (from input, possibly formatted)
Example:
caption = "Latin pop, reggaeton, flamenco-pop"
lyrics = "[Verse 1]\\nTengo un nudo en la garganta..."
metadata, status = handler.format_sample_from_input(caption, lyrics)
print(metadata['caption']) # "A dramatic and powerful Latin pop track..."
print(metadata['bpm']) # 100
"""
if not getattr(self, "llm_initialized", False):
return {}, "β 5Hz LM not initialized. Please initialize it first."
if not caption or not caption.strip():
caption = "NO USER INPUT"
if not lyrics or not lyrics.strip():
lyrics = "[Instrumental]"
logger.info(f"Formatting sample from input: caption={caption[:50]}..., lyrics length={len(lyrics)}")
# Build formatted prompt for format task
formatted_prompt = self.build_formatted_prompt_for_format(
caption=caption,
lyrics=lyrics,
)
logger.debug(f"Formatted prompt for format: {formatted_prompt}")
# Build constrained decoding metadata from user_metadata
constrained_metadata = None
if user_metadata:
constrained_metadata = {}
if user_metadata.get('bpm') is not None:
try:
bpm_val = int(user_metadata['bpm'])
if bpm_val > 0:
constrained_metadata['bpm'] = bpm_val
except (ValueError, TypeError):
pass
if user_metadata.get('duration') is not None:
try:
dur_val = int(user_metadata['duration'])
if dur_val > 0:
constrained_metadata['duration'] = dur_val
except (ValueError, TypeError):
pass
if user_metadata.get('keyscale'):
constrained_metadata['keyscale'] = user_metadata['keyscale']
if user_metadata.get('timesignature'):
constrained_metadata['timesignature'] = user_metadata['timesignature']
if user_metadata.get('language'):
constrained_metadata['language'] = user_metadata['language']
# Only use if we have at least one field
if not constrained_metadata:
constrained_metadata = None
else:
logger.info(f"Using user-provided metadata constraints: {constrained_metadata}")
# Generate using constrained decoding (format phase)
# Similar to understand/inspiration mode - generate metadata first (CoT), then formatted lyrics
# Note: cfg_scale and negative_prompt are not used in format mode
output_text, status = self.generate_from_formatted_prompt(
formatted_prompt=formatted_prompt,
cfg={
"temperature": temperature,
"top_k": top_k,
"top_p": top_p,
"repetition_penalty": repetition_penalty,
"target_duration": None, # No duration constraint for generation length
"user_metadata": constrained_metadata, # Inject user-provided metadata
"skip_caption": False, # Generate caption
"skip_language": constrained_metadata.get('language') is not None if constrained_metadata else False,
"skip_genres": False, # Generate genres
"generation_phase": "understand", # Use understand phase for metadata + free-form lyrics
"caption": "",
"lyrics": "",
},
use_constrained_decoding=use_constrained_decoding,
constrained_decoding_debug=constrained_decoding_debug,
stop_at_reasoning=False, # Continue after </think> to get formatted lyrics
)
if not output_text:
return {}, status
# Parse metadata and extract lyrics
metadata, _ = self.parse_lm_output(output_text)
# Extract formatted lyrics section (everything after </think>)
formatted_lyrics = self._extract_lyrics_from_output(output_text)
if formatted_lyrics:
metadata['lyrics'] = formatted_lyrics
else:
# If no lyrics generated, keep original input
metadata['lyrics'] = lyrics
logger.info(f"Format completed successfully. Generated {metadata} fields")
if constrained_decoding_debug:
logger.debug(f"Generated metadata: {list(metadata.keys())}")
logger.debug(f"Output text preview: {output_text[:300]}...")
status_msg = f"β
Format completed successfully\nGenerated fields: {', '.join(metadata.keys())}"
return metadata, status_msg
def generate_from_formatted_prompt(
self,
formatted_prompt: str,
cfg: Optional[Dict[str, Any]] = None,
use_constrained_decoding: bool = True,
constrained_decoding_debug: bool = False,
stop_at_reasoning: bool = False,
) -> Tuple[str, str]:
"""
Generate raw LM text output from a pre-built formatted prompt.
Args:
formatted_prompt: Prompt that is already formatted by `build_formatted_prompt`.
cfg: Optional dict supporting keys:
- temperature (float)
- cfg_scale (float)
- negative_prompt (str) used when cfg_scale > 1
- top_k (int), top_p (float), repetition_penalty (float)
- target_duration (float): Target duration in seconds for codes generation
- generation_phase (str): "cot" or "codes" for phase-aware CFG
use_constrained_decoding: Whether to use FSM-based constrained decoding
constrained_decoding_debug: Whether to enable debug logging for constrained decoding
stop_at_reasoning: If True, stop generation immediately after </think> tag (no audio codes)
Returns:
(output_text, status_message)
Example:
prompt = handler.build_formatted_prompt(caption, lyric)
text, status = handler.generate_from_formatted_prompt(prompt, {"temperature": 0.7})
"""
if not getattr(self, "llm_initialized", False):
return "", "β 5Hz LM not initialized. Please initialize it first."
if self.llm is None or self.llm_tokenizer is None:
return "", "β 5Hz LM is missing model or tokenizer."
cfg = cfg or {}
temperature = cfg.get("temperature", 0.6)
cfg_scale = cfg.get("cfg_scale", 1.0)
negative_prompt = cfg.get("negative_prompt", "NO USER INPUT")
top_k = cfg.get("top_k")
top_p = cfg.get("top_p")
repetition_penalty = cfg.get("repetition_penalty", 1.0)
target_duration = cfg.get("target_duration")
user_metadata = cfg.get("user_metadata") # User-provided metadata fields
skip_caption = cfg.get("skip_caption", False) # Skip caption generation in CoT
skip_language = cfg.get("skip_language", False) # Skip language generation in CoT
skip_genres = cfg.get("skip_genres", False) # Skip genres generation in CoT
generation_phase = cfg.get("generation_phase", "cot") # "cot" or "codes"
# Additional context for codes phase unconditional prompt building
caption = cfg.get("caption", "")
lyrics = cfg.get("lyrics", "")
cot_text = cfg.get("cot_text", "")
try:
if self.llm_backend == "vllm":
output_text = self._run_vllm(
formatted_prompts=formatted_prompt,
temperature=temperature,
cfg_scale=cfg_scale,
negative_prompt=negative_prompt,
top_k=top_k,
top_p=top_p,
repetition_penalty=repetition_penalty,
use_constrained_decoding=use_constrained_decoding,
constrained_decoding_debug=constrained_decoding_debug,
target_duration=target_duration,
user_metadata=user_metadata,
stop_at_reasoning=stop_at_reasoning,
skip_genres=skip_genres,
skip_caption=skip_caption,
skip_language=skip_language,
generation_phase=generation_phase,
caption=caption,
lyrics=lyrics,
cot_text=cot_text,
)
return output_text, f"β
Generated successfully (vllm) | length={len(output_text)}"
# PyTorch backend
output_text = self._run_pt(
formatted_prompts=formatted_prompt,
temperature=temperature,
cfg_scale=cfg_scale,
negative_prompt=negative_prompt,
top_k=top_k,
top_p=top_p,
repetition_penalty=repetition_penalty,
use_constrained_decoding=use_constrained_decoding,
constrained_decoding_debug=constrained_decoding_debug,
target_duration=target_duration,
user_metadata=user_metadata,
stop_at_reasoning=stop_at_reasoning,
skip_genres=skip_genres,
skip_caption=skip_caption,
skip_language=skip_language,
generation_phase=generation_phase,
caption=caption,
lyrics=lyrics,
cot_text=cot_text,
)
return output_text, f"β
Generated successfully (pt) | length={len(output_text)}"
except Exception as e:
return "", f"β Error generating from formatted prompt: {e}"
def _generate_with_constrained_decoding(
self,
input_ids: torch.Tensor,
attention_mask: Optional[torch.Tensor],
max_new_tokens: int,
temperature: float,
top_k: Optional[int],
top_p: Optional[float],
repetition_penalty: float,
pad_token_id: int,
streamer: Optional[BaseStreamer],
constrained_processor: Optional[MetadataConstrainedLogitsProcessor] = None,
) -> torch.Tensor:
"""
Custom generation loop with constrained decoding support (non-CFG).
This allows us to call update_state() after each token generation.
"""
model = self.llm
device = self.device
# Initialize generated sequences
generated_ids = input_ids.clone()
if attention_mask is not None:
attn_mask = attention_mask.clone()
else:
attn_mask = torch.ones_like(input_ids)
# Prepare model inputs
model_kwargs = {'attention_mask': attn_mask}
# Past key values for KV cache
past_key_values = None
use_cache = hasattr(model, 'generation_config') and getattr(model.generation_config, 'use_cache', True)
# Get EOS token ID
eos_token_id = self.llm_tokenizer.eos_token_id
if eos_token_id is None:
eos_token_id = pad_token_id
# Build logits processor for repetition penalty
logits_processor = self._build_logits_processor(repetition_penalty)
with torch.no_grad():
for step in range(max_new_tokens):
# Forward pass
outputs = self._forward_pass(model, generated_ids, model_kwargs, past_key_values, use_cache)
# Get logits for the last position
next_token_logits = outputs.logits[:, -1, :] # [batch_size, vocab_size]
# Apply constrained processor FIRST (modifies logits based on FSM state)
if constrained_processor is not None:
next_token_logits = constrained_processor(generated_ids, next_token_logits)
# Apply other logits processors (repetition penalty)
for processor in logits_processor:
next_token_logits = processor(generated_ids, next_token_logits)
# Apply top-k and top-p filtering
next_token_logits = self._apply_top_k_filter(next_token_logits, top_k)
next_token_logits = self._apply_top_p_filter(next_token_logits, top_p)
# Apply temperature and sample
next_tokens = self._sample_tokens(next_token_logits, temperature)
# Update constrained processor state
self._update_constrained_processor_state(constrained_processor, next_tokens)
# Check for EOS token
should_stop = self._check_eos_token(next_tokens, eos_token_id, pad_token_id)
# Append token to sequence
next_tokens_unsqueezed = next_tokens.unsqueeze(1)
generated_ids = torch.cat([generated_ids, next_tokens_unsqueezed], dim=1)
attn_mask = torch.cat([attn_mask, torch.ones((input_ids.shape[0], 1), device=device, dtype=attn_mask.dtype)], dim=1)
model_kwargs['attention_mask'] = attn_mask
# Update KV cache
if use_cache and hasattr(outputs, 'past_key_values'):
past_key_values = outputs.past_key_values
# Update streamer
if streamer is not None:
streamer.put(next_tokens_unsqueezed)
if should_stop:
break
if streamer is not None:
streamer.end()
return generated_ids
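# Note on the sampling helpers used in the loop above (defined elsewhere in this
# class; summarized here for readability, actual implementations may differ in detail):
#   _apply_top_k_filter(logits, k): keep the k largest logits per row, mask the rest to -inf.
#   _apply_top_p_filter(logits, p): keep the smallest set of tokens whose softmax mass
#       reaches p (nucleus sampling), mask the rest to -inf.
#   _sample_tokens(logits, temperature): scale logits by 1/temperature, softmax, and
#       draw one token id per row (e.g. via torch.multinomial).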
def _generate_with_cfg_custom(
self,
batch_input_ids: torch.Tensor,
batch_attention_mask: Optional[torch.Tensor],
max_new_tokens: int,
temperature: float,
cfg_scale: float,
top_k: Optional[int],
top_p: Optional[float],
repetition_penalty: float,
pad_token_id: int,
streamer: Optional[BaseStreamer],
constrained_processor: Optional[MetadataConstrainedLogitsProcessor] = None,
) -> torch.Tensor:
"""
Custom CFG generation loop that:
1. Processes both conditional and unconditional sequences in parallel
2. Applies CFG formula to logits
3. Samples tokens only for conditional sequences
4. Applies the same sampled tokens to both conditional and unconditional sequences
5. Optionally applies constrained decoding via FSM-based logits processor
Batch format: [cond_input, uncond_input]
"""
model = self.llm
device = self.device
batch_size = batch_input_ids.shape[0] // 2 # Half are conditional, half are unconditional
cond_start_idx = 0
uncond_start_idx = batch_size
# Initialize generated sequences
generated_ids = batch_input_ids.clone()
if batch_attention_mask is not None:
attention_mask = batch_attention_mask.clone()
else:
attention_mask = torch.ones_like(batch_input_ids)
# Prepare model inputs
model_kwargs = {}
if batch_attention_mask is not None:
model_kwargs['attention_mask'] = attention_mask
# Past key values for KV cache (if model supports it)
past_key_values = None
use_cache = hasattr(model, 'generation_config') and getattr(model.generation_config, 'use_cache', True)
# Get EOS token ID for stopping condition
eos_token_id = self.llm_tokenizer.eos_token_id
if eos_token_id is None:
eos_token_id = pad_token_id
# Build logits processor for non-CFG operations (repetition penalty, top_k, top_p)
logits_processor = self._build_logits_processor(repetition_penalty)
with torch.no_grad():
for step in range(max_new_tokens):
# Forward pass for the entire batch (conditional + unconditional)
outputs = self._forward_pass(model, generated_ids, model_kwargs, past_key_values, use_cache)
# Get logits for the last position
next_token_logits = outputs.logits[:, -1, :] # [batch_size*2, vocab_size]
# Split conditional and unconditional logits
cond_logits = next_token_logits[cond_start_idx:cond_start_idx+batch_size]
uncond_logits = next_token_logits[uncond_start_idx:uncond_start_idx+batch_size]
# Apply CFG formula: cfg_logits = uncond_logits + cfg_scale * (cond_logits - uncond_logits)
cfg_logits = uncond_logits + cfg_scale * (cond_logits - uncond_logits)
# Apply constrained processor FIRST (modifies logits based on FSM state)
if constrained_processor is not None:
current_input_ids = generated_ids[cond_start_idx:cond_start_idx+batch_size]
cfg_logits = constrained_processor(current_input_ids, cfg_logits)
# Apply logits processors (repetition penalty, top-k, top-p)
# Get current input_ids for repetition penalty (only conditional part)
current_input_ids = generated_ids[cond_start_idx:cond_start_idx+batch_size]
for processor in logits_processor:
cfg_logits = processor(current_input_ids, cfg_logits)
# Apply top-k and top-p filtering
cfg_logits = self._apply_top_k_filter(cfg_logits, top_k)
cfg_logits = self._apply_top_p_filter(cfg_logits, top_p)
# Apply temperature and sample
next_tokens = self._sample_tokens(cfg_logits, temperature)
# Update constrained processor state AFTER sampling
self._update_constrained_processor_state(constrained_processor, next_tokens)
# Check for EOS token in conditional sequences BEFORE unsqueezing
# Stop if any conditional sequence generates EOS token
# next_tokens shape: [batch_size] (only conditional tokens)
should_stop = self._check_eos_token(next_tokens, eos_token_id, pad_token_id)
# Apply the same sampled tokens to both conditional and unconditional sequences
next_tokens_unsqueezed = next_tokens.unsqueeze(1)
generated_ids = torch.cat([generated_ids, next_tokens_unsqueezed.repeat(2, 1)], dim=1)
attention_mask = torch.cat([attention_mask, torch.ones((batch_size*2, 1), device=device, dtype=attention_mask.dtype)], dim=1)
model_kwargs['attention_mask'] = attention_mask
# Update past_key_values for next iteration
if use_cache and hasattr(outputs, 'past_key_values'):
past_key_values = outputs.past_key_values
# Update streamer
if streamer is not None:
streamer.put(next_tokens_unsqueezed) # Stream conditional tokens
# Stop generation if EOS token detected
if should_stop:
break
if streamer is not None:
streamer.end()
# Return the full batch (both conditional and unconditional)
# The caller will extract only the conditional output
return generated_ids
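# Worked example of the CFG combination used above (illustrative numbers):
# with cfg_scale = 3.0, a token with cond_logit = 2.0 and uncond_logit = 0.5 gets
# cfg_logit = 0.5 + 3.0 * (2.0 - 0.5) = 5.0, so tokens favoured by the conditional
# prompt relative to the unconditional one are boosted; cfg_scale = 1.0 reduces to
# plain conditional sampling because the unconditional term cancels out.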
def parse_lm_output(self, output_text: str) -> Tuple[Dict[str, Any], str]:
"""
Parse LM output to extract metadata and audio codes.
Expected format:
<think>
bpm: 73
caption: A calm piano melody
duration: 273
genres: Chinese folk
keyscale: G major
language: en
timesignature: 4
</think>
<|audio_code_56535|><|audio_code_62918|>...
Returns:
Tuple of (metadata_dict, audio_codes_string)
"""
debug_output_text = output_text.split("</think>")[0]
logger.debug(f"Debug output text: {debug_output_text}")
metadata = {}
audio_codes = ""
import re
# Extract audio codes - find all <|audio_code_XXX|> patterns
code_pattern = r'<\|audio_code_\d+\|>'
code_matches = re.findall(code_pattern, output_text)
if code_matches:
audio_codes = "".join(code_matches)
# Extract metadata from reasoning section
# Try different reasoning tag patterns
reasoning_patterns = [
r'<think>(.*?)</think>',
r'<think>(.*?)</think>',
r'<reasoning>(.*?)</reasoning>',
]
reasoning_text = None
for pattern in reasoning_patterns:
match = re.search(pattern, output_text, re.DOTALL)
if match:
reasoning_text = match.group(1).strip()
break
# If no reasoning tags found, try to parse metadata from the beginning of output
if not reasoning_text:
# Look for metadata lines before audio codes
lines_before_codes = output_text.split('<|audio_code_')[0] if '<|audio_code_' in output_text else output_text
reasoning_text = lines_before_codes.strip()
# Parse metadata fields with YAML multi-line value support
if reasoning_text:
lines = reasoning_text.split('\n')
current_key = None
current_value_lines = []
def save_current_field():
"""Save the accumulated field value"""
nonlocal current_key, current_value_lines
if current_key and current_value_lines:
# Join multi-line value
value = '\n'.join(current_value_lines)
if current_key == 'bpm':
try:
metadata['bpm'] = int(value.strip())
except (ValueError, TypeError):
metadata['bpm'] = value.strip()
elif current_key == 'caption':
# Post-process caption to remove YAML multi-line formatting
metadata['caption'] = MetadataConstrainedLogitsProcessor.postprocess_caption(value)
elif current_key == 'duration':
try:
metadata['duration'] = int(value.strip())
except (ValueError, TypeError):
metadata['duration'] = value.strip()
elif current_key == 'genres':
metadata['genres'] = value.strip()
elif current_key == 'keyscale':
metadata['keyscale'] = value.strip()
elif current_key == 'language':
metadata['language'] = value.strip()
elif current_key == 'timesignature':
metadata['timesignature'] = value.strip()
current_key = None
current_value_lines = []
for line in lines:
# Skip lines starting with '<' (tags)
if line.strip().startswith('<'):
continue
# Check if this is a new field (no leading spaces and contains ':')
if line and not line[0].isspace() and ':' in line:
# Save previous field if any
save_current_field()
# Parse new field
parts = line.split(':', 1)
if len(parts) == 2:
current_key = parts[0].strip().lower()
# First line of value (after colon)
first_value = parts[1]
if first_value.strip():
current_value_lines.append(first_value)
elif line.startswith(' ') or line.startswith('\t'):
# Continuation line (YAML multi-line value)
if current_key:
current_value_lines.append(line)
# Don't forget to save the last field
save_current_field()
return metadata, audio_codes
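# Illustrative parse (hypothetical values): for
#   "<think>\nbpm: 73\ncaption: |\n  A calm piano melody\nduration: 273\n</think><|audio_code_1|><|audio_code_2|>"
# parse_lm_output returns
#   ({'bpm': 73, 'caption': 'A calm piano melody', 'duration': 273}, "<|audio_code_1|><|audio_code_2|>")
# assuming postprocess_caption collapses the YAML block-scalar ("|") formatting of the caption.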
@contextmanager
def _load_model_context(self):
"""
Context manager to load a model to GPU and offload it back to CPU after use.
Only used for PyTorch backend when offload_to_cpu is True.
"""
if not self.offload_to_cpu:
yield
return
# If using nanovllm, do not offload (it stays on GPU)
if self.llm_backend == "vllm":
yield
return
model = self.llm
if model is None:
yield
return
# Load to GPU
logger.info(f"Loading LLM to {self.device}")
start_time = time.time()
if hasattr(model, "to"):
model.to(self.device).to(self.dtype)
load_time = time.time() - start_time
logger.info(f"Loaded LLM to {self.device} in {load_time:.4f}s")
try:
yield
finally:
# Offload to CPU
logger.info(f"Offloading LLM to CPU")
start_time = time.time()
if hasattr(model, "to"):
model.to("cpu")
torch.cuda.empty_cache()
offload_time = time.time() - start_time
logger.info(f"Offloaded LLM to CPU in {offload_time:.4f}s")
def get_hf_model_for_scoring(self):
"""
Get HuggingFace model for perplexity scoring.
For vllm backend, loads HuggingFace model from disk (weights are cached by transformers).
For pt backend, returns the existing model.
Returns:
HuggingFace model instance
"""
if self.llm_backend == "pt":
# For PyTorch backend, directly return the model
return self.llm
elif self.llm_backend == "vllm":
# For vllm backend, load HuggingFace model from disk
# Note: transformers caches model weights, so this doesn't duplicate disk I/O
if self._hf_model_for_scoring is None:
logger.info("Loading HuggingFace model for scoring (from checkpoint)")
# Get model path from vllm config
model_runner = self.llm.model_runner
model_path = model_runner.config.model
# Load HuggingFace model from the same checkpoint
# This will load the original unfused weights
import time
start_time = time.time()
self._hf_model_for_scoring = AutoModelForCausalLM.from_pretrained(
model_path,
trust_remote_code=True,
torch_dtype=self.dtype
)
load_time = time.time() - start_time
logger.info(f"HuggingFace model loaded in {load_time:.2f}s")
# Move to same device as vllm model
device = next(model_runner.model.parameters()).device
self._hf_model_for_scoring = self._hf_model_for_scoring.to(device)
self._hf_model_for_scoring.eval()
logger.info(f"HuggingFace model for scoring ready on {device}")
return self._hf_model_for_scoring
else:
raise ValueError(f"Unknown backend: {self.llm_backend}")