"""
ACE-Step Inference API Module
This module provides a standardized inference interface for music generation,
designed for third-party integration. It offers both a simplified API and
backward-compatible Gradio UI support.
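
Example (a minimal sketch; assumes `dit_handler` and `llm_handler` are
already-initialized AceStepHandler / LLMHandler instances and that
"./outputs" is writable):

    >>> params = GenerationParams(caption="mellow lo-fi hip hop", duration=60)
    >>> config = GenerationConfig(batch_size=1, audio_format="flac")
    >>> result = generate_music(dit_handler, llm_handler, params, config,
    ...                         save_dir="./outputs")
    >>> if result.success:
    ...     print(result.audios[0]["path"])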
"""
import math
import os
import tempfile
from typing import Optional, Union, List, Dict, Any, Tuple
from dataclasses import dataclass, field, asdict
from loguru import logger
from acestep.audio_utils import AudioSaver, generate_uuid_from_params
# HuggingFace Space environment detection
IS_HUGGINGFACE_SPACE = os.environ.get("SPACE_ID") is not None
def _get_spaces_gpu_decorator(duration=180):
"""
Get the @spaces.GPU decorator if running in HuggingFace Space environment.
Returns identity decorator if not in Space environment.
"""
if IS_HUGGINGFACE_SPACE:
try:
import spaces
return spaces.GPU(duration=duration)
except ImportError:
logger.warning("spaces package not found, GPU decorator disabled")
return lambda func: func
return lambda func: func
@dataclass
class GenerationParams:
"""Configuration for music generation parameters.
Attributes:
# Text Inputs
        caption: A short text prompt describing the desired music (main prompt). Max 512 characters.
        lyrics: Lyrics for the music. Use "[Instrumental]" for instrumental songs. Max 4096 characters.
        instrumental: If True, generate instrumental music regardless of lyrics.
        # Music Metadata
        bpm: BPM (beats per minute), e.g., 120. Range 30–300. Set to None for automatic estimation.
        keyscale: Musical key (e.g., "C Major", "Am"): root A–G, optional #/♭, major/minor. Leave empty for auto-detection.
        timesignature: Time signature (2 for '2/4', 3 for '3/4', 4 for '4/4', 6 for '6/8'). Leave empty for auto-detection.
        vocal_language: Language code for vocals, e.g., "en", "zh", "ja", or "unknown". See acestep/constants.py:VALID_LANGUAGES.
        duration: Target audio length in seconds (10–600). If <0 or None, the model chooses automatically.
# Generation Parameters
inference_steps: Number of diffusion steps (e.g., 8 for turbo, 32–100 for base model).
        guidance_scale: CFG (classifier-free guidance) strength. Higher values follow the prompt more strictly. Only supported for the non-turbo (base) model.
seed: Integer seed for reproducibility. -1 means use random seed each time.
# Advanced DiT Parameters
use_adg: Whether to use Adaptive Dual Guidance (only works for base model).
cfg_interval_start: Start ratio (0.0–1.0) to apply CFG.
cfg_interval_end: End ratio (0.0–1.0) to apply CFG.
shift: Timestep shift factor (default 1.0). When != 1.0, applies t = shift * t / (1 + (shift - 1) * t) to timesteps.
# Task-Specific Parameters
task_type: Type of generation task. One of: "text2music", "cover", "repaint", "lego", "extract", "complete".
reference_audio: Path to a reference audio file for style transfer or cover tasks.
src_audio: Path to a source audio file for audio-to-audio tasks.
audio_codes: Audio semantic codes as a string (advanced use, for code-control generation).
repainting_start: For repaint/lego tasks: start time in seconds for region to repaint.
repainting_end: For repaint/lego tasks: end time in seconds for region to repaint (-1 for until end).
        audio_cover_strength: Strength of reference audio/codes influence (range 0.0–1.0). Set lower (e.g., 0.2) for style-transfer tasks.
instruction: Optional task instruction prompt. If empty, auto-generated by system.
# 5Hz Language Model Parameters for CoT reasoning
thinking: If True, enable 5Hz Language Model "Chain-of-Thought" reasoning for semantic/music metadata and codes.
lm_temperature: Sampling temperature for the LLM (0.0–2.0). Higher = more creative/varied results.
lm_cfg_scale: Classifier-free guidance scale for the LLM.
lm_top_k: LLM top-k sampling (0 = disabled).
lm_top_p: LLM top-p nucleus sampling (1.0 = disabled).
lm_negative_prompt: Negative prompt to use for LLM (for control).
use_cot_metas: Whether to let LLM generate music metadata via CoT reasoning.
use_cot_caption: Whether to let LLM rewrite or format the input caption via CoT reasoning.
use_cot_language: Whether to let LLM detect vocal language via CoT.
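
    Example (a minimal sketch of a style-transfer cover; "ref.wav" is a
    hypothetical path):
        >>> params = GenerationParams(task_type="cover",
        ...                           reference_audio="ref.wav",
        ...                           audio_cover_strength=0.2)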
"""
# Required Inputs
task_type: str = "text2music"
instruction: str = "Fill the audio semantic mask based on the given conditions:"
# Audio Uploads
reference_audio: Optional[str] = None
src_audio: Optional[str] = None
# LM Codes Hints
audio_codes: str = ""
# Text Inputs
caption: str = ""
lyrics: str = ""
instrumental: bool = False
# Metadata
vocal_language: str = "unknown"
bpm: Optional[int] = None
keyscale: str = ""
timesignature: str = ""
duration: float = -1.0
# Advanced Settings
inference_steps: int = 8
seed: int = -1
guidance_scale: float = 7.0
use_adg: bool = False
cfg_interval_start: float = 0.0
cfg_interval_end: float = 1.0
shift: float = 1.0
infer_method: str = "ode" # "ode" or "sde" - diffusion inference method
# Custom timesteps (parsed from string like "0.97,0.76,0.615,0.5,0.395,0.28,0.18,0.085,0")
# If provided, overrides inference_steps and shift
timesteps: Optional[List[float]] = None
repainting_start: float = 0.0
repainting_end: float = -1
audio_cover_strength: float = 1.0
# 5Hz Language Model Parameters
thinking: bool = True
lm_temperature: float = 0.85
lm_cfg_scale: float = 2.0
lm_top_k: int = 0
lm_top_p: float = 0.9
lm_negative_prompt: str = "NO USER INPUT"
use_cot_metas: bool = True
use_cot_caption: bool = True
use_cot_lyrics: bool = False # TODO: not used yet
use_cot_language: bool = True
use_constrained_decoding: bool = True
cot_bpm: Optional[int] = None
cot_keyscale: str = ""
cot_timesignature: str = ""
cot_duration: Optional[float] = None
cot_vocal_language: str = "unknown"
cot_caption: str = ""
cot_lyrics: str = ""
def to_dict(self) -> Dict[str, Any]:
"""Convert config to dictionary for JSON serialization."""
return asdict(self)
@dataclass
class GenerationConfig:
"""Configuration for music generation.
Attributes:
batch_size: Number of audio samples to generate
allow_lm_batch: Whether to allow batch processing in LM
use_random_seed: Whether to use random seed
seeds: Seed(s) for batch generation. Can be:
- None: Use random seeds (when use_random_seed=True) or params.seed (when use_random_seed=False)
- List[int]: List of seeds, will be padded with random seeds if fewer than batch_size
- int: Single seed value (will be converted to list and padded)
lm_batch_chunk_size: Batch chunk size for LM processing
constrained_decoding_debug: Whether to enable constrained decoding debug
audio_format: Output audio format, one of "mp3", "wav", "flac". Default: "flac"
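
    Example (a minimal sketch; requests 3 samples with two fixed seeds, so the
    third is padded with a random seed at generation time):
        >>> config = GenerationConfig(batch_size=3, use_random_seed=False,
        ...                           seeds=[42, 123])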
"""
batch_size: int = 2
allow_lm_batch: bool = False
use_random_seed: bool = True
seeds: Optional[List[int]] = None
lm_batch_chunk_size: int = 8
constrained_decoding_debug: bool = False
audio_format: str = "flac" # Default to FLAC for fast saving
def to_dict(self) -> Dict[str, Any]:
"""Convert config to dictionary for JSON serialization."""
return asdict(self)
@dataclass
class GenerationResult:
"""Result of music generation.
Attributes:
# Audio Outputs
audios: List of audio dictionaries with paths, keys, params
status_message: Status message from generation
extra_outputs: Extra outputs from generation
success: Whether generation completed successfully
error: Error message if generation failed
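
    Example (a minimal sketch of consuming a successful result):
        >>> for audio in result.audios:
        ...     print(audio["key"], audio["sample_rate"], audio["path"])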
"""
# Audio Outputs
audios: List[Dict[str, Any]] = field(default_factory=list)
# Generation Information
status_message: str = ""
extra_outputs: Dict[str, Any] = field(default_factory=dict)
# Success Status
success: bool = True
error: Optional[str] = None
def to_dict(self) -> Dict[str, Any]:
"""Convert result to dictionary for JSON serialization."""
return asdict(self)
@dataclass
class UnderstandResult:
"""Result of music understanding from audio codes.
Attributes:
# Metadata Fields
caption: Generated caption describing the music
lyrics: Generated or extracted lyrics
bpm: Beats per minute (None if not detected)
duration: Duration in seconds (None if not detected)
keyscale: Musical key (e.g., "C Major")
language: Vocal language code (e.g., "en", "zh")
timesignature: Time signature (e.g., "4/4")
# Status
status_message: Status message from understanding
success: Whether understanding completed successfully
error: Error message if understanding failed
"""
# Metadata Fields
caption: str = ""
lyrics: str = ""
bpm: Optional[int] = None
duration: Optional[float] = None
keyscale: str = ""
language: str = ""
timesignature: str = ""
# Status
status_message: str = ""
success: bool = True
error: Optional[str] = None
def to_dict(self) -> Dict[str, Any]:
"""Convert result to dictionary for JSON serialization."""
return asdict(self)
def _update_metadata_from_lm(
metadata: Dict[str, Any],
bpm: Optional[int],
key_scale: str,
time_signature: str,
audio_duration: Optional[float],
vocal_language: str,
caption: str,
lyrics: str,
) -> Tuple[Optional[int], str, str, Optional[float], str, str, str]:
"""Update metadata fields from LM output if not provided by user."""
if bpm is None and metadata.get('bpm'):
bpm_value = metadata.get('bpm')
if bpm_value not in ["N/A", ""]:
try:
bpm = int(bpm_value)
except (ValueError, TypeError):
pass
if not key_scale and metadata.get('keyscale'):
key_scale_value = metadata.get('keyscale', metadata.get('key_scale', ""))
if key_scale_value != "N/A":
key_scale = key_scale_value
if not time_signature and metadata.get('timesignature'):
time_signature_value = metadata.get('timesignature', metadata.get('time_signature', ""))
if time_signature_value != "N/A":
time_signature = time_signature_value
if audio_duration is None or audio_duration <= 0:
audio_duration_value = metadata.get('duration', -1)
if audio_duration_value not in ["N/A", ""]:
try:
audio_duration = float(audio_duration_value)
except (ValueError, TypeError):
pass
if not vocal_language and metadata.get('vocal_language'):
vocal_language = metadata.get('vocal_language')
if not caption and metadata.get('caption'):
caption = metadata.get('caption')
if not lyrics and metadata.get('lyrics'):
lyrics = metadata.get('lyrics')
return bpm, key_scale, time_signature, audio_duration, vocal_language, caption, lyrics
@_get_spaces_gpu_decorator(duration=180)
def generate_music(
dit_handler,
llm_handler,
params: GenerationParams,
config: GenerationConfig,
save_dir: Optional[str] = None,
progress=None,
) -> GenerationResult:
"""Generate music using ACE-Step model with optional LM reasoning.
    Args:
        dit_handler: Initialized DiT model handler (AceStepHandler instance)
        llm_handler: Initialized LLM handler (LLMHandler instance)
        params: Generation parameters (GenerationParams instance)
        config: Generation configuration (GenerationConfig instance)
        save_dir: Directory for saved audio files (created if missing). If None,
            audio tensors are returned but no files are written.
        progress: Optional progress callback forwarded to the handlers.
Returns:
GenerationResult with generated audio files and metadata
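    Example (a minimal sketch; handlers are assumed to be initialized):
        >>> result = generate_music(dit_handler, llm_handler,
        ...                         GenerationParams(caption="synthwave"),
        ...                         GenerationConfig(batch_size=1),
        ...                         save_dir="./outputs")
        >>> result.extra_outputs.get("time_costs", {}).get("pipeline_total_time")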
"""
try:
# Phase 1: LM-based metadata and code generation (if enabled)
audio_code_string_to_use = params.audio_codes
lm_generated_metadata = None
lm_generated_audio_codes_list = []
lm_total_time_costs = {
"phase1_time": 0.0,
"phase2_time": 0.0,
"total_time": 0.0,
}
# Extract mutable copies of metadata (will be updated by LM if needed)
bpm = params.bpm
key_scale = params.keyscale
time_signature = params.timesignature
audio_duration = params.duration
dit_input_caption = params.caption
dit_input_vocal_language = params.vocal_language
dit_input_lyrics = params.lyrics
# Determine if we need to generate audio codes
# If user has provided audio_codes, we don't need to generate them
# Otherwise, check if we need audio codes (lm_dit mode) or just metas (dit mode)
user_provided_audio_codes = bool(params.audio_codes and str(params.audio_codes).strip())
        # Determine infer_type: "llm_dit" generates both metas and audio codes,
        # "dit" generates only metas. Codes must be generated whenever the user
        # has not supplied them.
        need_audio_codes = not user_provided_audio_codes
        # LM generation always runs in chunks for consistency; determine the
        # actual batch size for chunk processing
actual_batch_size = config.batch_size if config.batch_size is not None else 1
        # Prepare seeds for batch generation
        # Convert config.seeds (None or List[int]) to the comma-separated string
        # format that dit_handler.prepare_seeds accepts
seed_for_generation = ""
if config.seeds is not None and len(config.seeds) > 0:
if isinstance(config.seeds, list):
# Convert List[int] to comma-separated string
seed_for_generation = ",".join(str(s) for s in config.seeds)
# Use dit_handler.prepare_seeds to handle seed list generation and padding
# This will handle all the logic: padding with random seeds if needed, etc.
actual_seed_list, _ = dit_handler.prepare_seeds(actual_batch_size, seed_for_generation, config.use_random_seed)
# LM-based Chain-of-Thought reasoning
# Skip LM for cover/repaint tasks - these tasks use reference/src audio directly
# and don't need LM to generate audio codes
skip_lm_tasks = {"cover", "repaint"}
# Determine if we should use LLM
# LLM is needed for:
# 1. thinking=True: generate audio codes via LM
# 2. use_cot_caption=True: enhance/generate caption via CoT
# 3. use_cot_language=True: detect vocal language via CoT
# 4. use_cot_metas=True: fill missing metadata via CoT
need_lm_for_cot = params.use_cot_caption or params.use_cot_language or params.use_cot_metas
use_lm = (params.thinking or need_lm_for_cot) and llm_handler.llm_initialized and params.task_type not in skip_lm_tasks
lm_status = []
if params.task_type in skip_lm_tasks:
logger.info(f"Skipping LM for task_type='{params.task_type}' - using DiT directly")
logger.info(f"[generate_music] LLM usage decision: thinking={params.thinking}, "
f"use_cot_caption={params.use_cot_caption}, use_cot_language={params.use_cot_language}, "
f"use_cot_metas={params.use_cot_metas}, need_lm_for_cot={need_lm_for_cot}, "
f"llm_initialized={llm_handler.llm_initialized if llm_handler else False}, use_lm={use_lm}")
if use_lm:
# Convert sampling parameters - handle None values safely
top_k_value = None if not params.lm_top_k or params.lm_top_k == 0 else int(params.lm_top_k)
top_p_value = None if not params.lm_top_p or params.lm_top_p >= 1.0 else params.lm_top_p
# Build user_metadata from user-provided values
user_metadata = {}
if bpm is not None:
try:
bpm_value = float(bpm)
if bpm_value > 0:
user_metadata['bpm'] = int(bpm_value)
except (ValueError, TypeError):
pass
if key_scale and key_scale.strip():
key_scale_clean = key_scale.strip()
if key_scale_clean.lower() not in ["n/a", ""]:
user_metadata['keyscale'] = key_scale_clean
if time_signature and time_signature.strip():
time_sig_clean = time_signature.strip()
if time_sig_clean.lower() not in ["n/a", ""]:
user_metadata['timesignature'] = time_sig_clean
if audio_duration is not None:
try:
duration_value = float(audio_duration)
if duration_value > 0:
user_metadata['duration'] = int(duration_value)
except (ValueError, TypeError):
pass
user_metadata_to_pass = user_metadata if user_metadata else None
# Determine infer_type based on whether we need audio codes
# - "llm_dit": generates both metas and audio codes (two-phase internally)
# - "dit": generates only metas (single phase)
infer_type = "llm_dit" if need_audio_codes and params.thinking else "dit"
# Use chunk size from config, or default to batch_size if not set
max_inference_batch_size = int(config.lm_batch_chunk_size) if config.lm_batch_chunk_size > 0 else actual_batch_size
num_chunks = math.ceil(actual_batch_size / max_inference_batch_size)
all_metadata_list = []
all_audio_codes_list = []
for chunk_idx in range(num_chunks):
chunk_start = chunk_idx * max_inference_batch_size
chunk_end = min(chunk_start + max_inference_batch_size, actual_batch_size)
chunk_size = chunk_end - chunk_start
chunk_seeds = actual_seed_list[chunk_start:chunk_end] if chunk_start < len(actual_seed_list) else None
logger.info(f"LM chunk {chunk_idx+1}/{num_chunks} (infer_type={infer_type}) "
f"(size: {chunk_size}, seeds: {chunk_seeds})")
# Use the determined infer_type
# - "llm_dit" will internally run two phases (metas + codes)
# - "dit" will only run phase 1 (metas only)
result = llm_handler.generate_with_stop_condition(
caption=params.caption or "",
lyrics=params.lyrics or "",
infer_type=infer_type,
temperature=params.lm_temperature,
cfg_scale=params.lm_cfg_scale,
negative_prompt=params.lm_negative_prompt,
top_k=top_k_value,
top_p=top_p_value,
target_duration=audio_duration, # Pass duration to limit audio codes generation
user_metadata=user_metadata_to_pass,
use_cot_caption=params.use_cot_caption,
use_cot_language=params.use_cot_language,
use_cot_metas=params.use_cot_metas,
use_constrained_decoding=params.use_constrained_decoding,
constrained_decoding_debug=config.constrained_decoding_debug,
batch_size=chunk_size,
seeds=chunk_seeds,
progress=progress,
)
# Check if LM generation failed
if not result.get("success", False):
error_msg = result.get("error", "Unknown LM error")
lm_status.append(f"❌ LM Error: {error_msg}")
# Return early with error
return GenerationResult(
audios=[],
status_message=f"❌ LM generation failed: {error_msg}",
extra_outputs={},
success=False,
error=error_msg,
)
# Extract metadata and audio_codes from result dict
if chunk_size > 1:
metadata_list = result.get("metadata", [])
audio_codes_list = result.get("audio_codes", [])
all_metadata_list.extend(metadata_list)
all_audio_codes_list.extend(audio_codes_list)
else:
metadata = result.get("metadata", {})
audio_codes = result.get("audio_codes", "")
all_metadata_list.append(metadata)
all_audio_codes_list.append(audio_codes)
# Collect time costs from LM extra_outputs
lm_extra = result.get("extra_outputs", {})
lm_chunk_time_costs = lm_extra.get("time_costs", {})
if lm_chunk_time_costs:
# Accumulate time costs from all chunks
for key in ["phase1_time", "phase2_time", "total_time"]:
if key in lm_chunk_time_costs:
lm_total_time_costs[key] += lm_chunk_time_costs[key]
time_str = ", ".join([f"{k}: {v:.2f}s" for k, v in lm_chunk_time_costs.items()])
lm_status.append(f"✅ LM chunk {chunk_idx+1}: {time_str}")
lm_generated_metadata = all_metadata_list[0] if all_metadata_list else None
lm_generated_audio_codes_list = all_audio_codes_list
# Set audio_code_string_to_use based on infer_type
if infer_type == "llm_dit":
# If batch mode, use list; otherwise use single string
if actual_batch_size > 1:
audio_code_string_to_use = all_audio_codes_list
else:
audio_code_string_to_use = all_audio_codes_list[0] if all_audio_codes_list else ""
else:
# For "dit" mode, keep user-provided codes or empty
audio_code_string_to_use = params.audio_codes
# Update metadata from LM if not provided by user
if lm_generated_metadata:
bpm, key_scale, time_signature, audio_duration, vocal_language, caption, lyrics = _update_metadata_from_lm(
metadata=lm_generated_metadata,
bpm=bpm,
key_scale=key_scale,
time_signature=time_signature,
audio_duration=audio_duration,
vocal_language=dit_input_vocal_language,
caption=dit_input_caption,
lyrics=dit_input_lyrics)
if not params.bpm:
params.cot_bpm = bpm
if not params.keyscale:
params.cot_keyscale = key_scale
if not params.timesignature:
params.cot_timesignature = time_signature
            if params.duration is None or params.duration <= 0:
                params.cot_duration = audio_duration
            if not params.vocal_language or params.vocal_language == "unknown":
                params.cot_vocal_language = vocal_language
if not params.caption:
params.cot_caption = caption
if not params.lyrics:
params.cot_lyrics = lyrics
# set cot caption and language if needed
if params.use_cot_caption:
dit_input_caption = lm_generated_metadata.get("caption", dit_input_caption)
if params.use_cot_language:
dit_input_vocal_language = lm_generated_metadata.get("vocal_language", dit_input_vocal_language)
# Phase 2: DiT music generation
        # Use seed_for_generation (derived from config.seeds) rather than params.seed
result = dit_handler.generate_music(
captions=dit_input_caption,
lyrics=dit_input_lyrics,
bpm=bpm,
key_scale=key_scale,
time_signature=time_signature,
vocal_language=dit_input_vocal_language,
inference_steps=params.inference_steps,
guidance_scale=params.guidance_scale,
use_random_seed=config.use_random_seed,
            seed=seed_for_generation,  # Derived from config.seeds, not params.seed
reference_audio=params.reference_audio,
audio_duration=audio_duration,
batch_size=config.batch_size if config.batch_size is not None else 1,
src_audio=params.src_audio,
audio_code_string=audio_code_string_to_use,
repainting_start=params.repainting_start,
repainting_end=params.repainting_end,
instruction=params.instruction,
audio_cover_strength=params.audio_cover_strength,
task_type=params.task_type,
use_adg=params.use_adg,
cfg_interval_start=params.cfg_interval_start,
cfg_interval_end=params.cfg_interval_end,
shift=params.shift,
infer_method=params.infer_method,
timesteps=params.timesteps,
progress=progress,
)
# Check if generation failed
if not result.get("success", False):
return GenerationResult(
audios=[],
status_message=result.get("status_message", ""),
extra_outputs={},
success=False,
error=result.get("error"),
)
# Extract results from dit_handler.generate_music dict
dit_audios = result.get("audios", [])
status_message = result.get("status_message", "")
dit_extra_outputs = result.get("extra_outputs", {})
        # Use the seed list already prepared above via dit_handler.prepare_seeds
        seed_list = actual_seed_list
# Get base params dictionary
base_params_dict = params.to_dict()
# Save audio files using AudioSaver (format from config)
audio_format = config.audio_format if config.audio_format else "flac"
audio_saver = AudioSaver(default_format=audio_format)
        # Ensure the save directory exists before writing files
if save_dir is not None:
os.makedirs(save_dir, exist_ok=True)
# Build audios list for GenerationResult with params and save files
# Audio saving and UUID generation handled here, outside of handler
audios = []
for idx, dit_audio in enumerate(dit_audios):
# Create a copy of params dict for this audio
audio_params = base_params_dict.copy()
# Update audio-specific values
audio_params["seed"] = seed_list[idx] if idx < len(seed_list) else None
# Add audio codes if batch mode
if lm_generated_audio_codes_list and idx < len(lm_generated_audio_codes_list):
audio_params["audio_codes"] = lm_generated_audio_codes_list[idx]
# Get audio tensor and metadata
audio_tensor = dit_audio.get("tensor")
sample_rate = dit_audio.get("sample_rate", 48000)
            # Generate UUID for this audio (moved from handler)
            audio_key = generate_uuid_from_params(audio_params)
# Save audio file (handled outside handler)
audio_path = None
if audio_tensor is not None and save_dir is not None:
try:
audio_file = os.path.join(save_dir, f"{audio_key}.{audio_format}")
audio_path = audio_saver.save_audio(audio_tensor,
audio_file,
sample_rate=sample_rate,
format=audio_format,
channels_first=True)
except Exception as e:
logger.error(f"[generate_music] Failed to save audio file: {e}")
audio_path = "" # Fallback to empty path
audio_dict = {
"path": audio_path or "", # File path (saved here, not in handler)
"tensor": audio_tensor, # Audio tensor [channels, samples], CPU, float32
"key": audio_key,
"sample_rate": sample_rate,
"params": audio_params,
}
audios.append(audio_dict)
# Merge extra_outputs: include dit_extra_outputs (latents, masks) and add LM metadata
extra_outputs = dit_extra_outputs.copy()
extra_outputs["lm_metadata"] = lm_generated_metadata
# Merge time_costs from both LM and DiT into a unified dictionary
unified_time_costs = {}
# Add LM time costs (if LM was used)
if use_lm and lm_total_time_costs:
for key, value in lm_total_time_costs.items():
unified_time_costs[f"lm_{key}"] = value
# Add DiT time costs (if available)
dit_time_costs = dit_extra_outputs.get("time_costs", {})
if dit_time_costs:
for key, value in dit_time_costs.items():
unified_time_costs[f"dit_{key}"] = value
# Calculate total pipeline time
if unified_time_costs:
lm_total = unified_time_costs.get("lm_total_time", 0.0)
dit_total = unified_time_costs.get("dit_total_time_cost", 0.0)
unified_time_costs["pipeline_total_time"] = lm_total + dit_total
# Update extra_outputs with unified time_costs
extra_outputs["time_costs"] = unified_time_costs
        if lm_status:
            status_message = "\n".join(lm_status) + "\n" + status_message
# Create and return GenerationResult
return GenerationResult(
audios=audios,
status_message=status_message,
extra_outputs=extra_outputs,
success=True,
error=None,
)
except Exception as e:
logger.exception("Music generation failed")
return GenerationResult(
audios=[],
status_message=f"Error: {str(e)}",
extra_outputs={},
success=False,
error=str(e),
)
def understand_music(
llm_handler,
audio_codes: str,
temperature: float = 0.85,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
repetition_penalty: float = 1.0,
use_constrained_decoding: bool = True,
constrained_decoding_debug: bool = False,
) -> UnderstandResult:
"""Understand music from audio codes using the 5Hz Language Model.
This function analyzes audio semantic codes and generates metadata about the music,
including caption, lyrics, BPM, duration, key scale, language, and time signature.
If audio_codes is empty or "NO USER INPUT", the LM will generate a sample example
instead of analyzing existing codes.
Note: cfg_scale and negative_prompt are not supported in understand mode.
Args:
llm_handler: Initialized LLM handler (LLMHandler instance)
audio_codes: String of audio code tokens (e.g., "<|audio_code_123|><|audio_code_456|>...")
Use empty string or "NO USER INPUT" to generate a sample example.
temperature: Sampling temperature for generation (0.0-2.0). Higher = more creative.
top_k: Top-K sampling (None or 0 = disabled)
top_p: Top-P (nucleus) sampling (None or 1.0 = disabled)
repetition_penalty: Repetition penalty (1.0 = no penalty)
use_constrained_decoding: Whether to use FSM-based constrained decoding for metadata
constrained_decoding_debug: Whether to enable debug logging for constrained decoding
Returns:
UnderstandResult with parsed metadata fields and status
Example:
>>> result = understand_music(llm_handler, audio_codes="<|audio_code_123|>...")
>>> if result.success:
... print(f"Caption: {result.caption}")
... print(f"BPM: {result.bpm}")
... print(f"Lyrics: {result.lyrics}")
"""
# Check if LLM is initialized
if not llm_handler.llm_initialized:
return UnderstandResult(
status_message="5Hz LM not initialized. Please initialize it first.",
success=False,
error="LLM not initialized",
)
# If codes are empty, use "NO USER INPUT" to generate a sample example
if not audio_codes or not audio_codes.strip():
audio_codes = "NO USER INPUT"
try:
# Call LLM understanding
metadata, status = llm_handler.understand_audio_from_codes(
audio_codes=audio_codes,
temperature=temperature,
top_k=top_k,
top_p=top_p,
repetition_penalty=repetition_penalty,
use_constrained_decoding=use_constrained_decoding,
constrained_decoding_debug=constrained_decoding_debug,
)
# Check if LLM returned empty metadata (error case)
if not metadata:
return UnderstandResult(
status_message=status or "Failed to understand audio codes",
success=False,
error=status or "Empty metadata returned",
)
# Extract and convert fields
caption = metadata.get('caption', '')
lyrics = metadata.get('lyrics', '')
keyscale = metadata.get('keyscale', '')
language = metadata.get('language', metadata.get('vocal_language', ''))
timesignature = metadata.get('timesignature', '')
# Convert BPM to int
bpm = None
bpm_value = metadata.get('bpm')
if bpm_value is not None and bpm_value != 'N/A' and bpm_value != '':
try:
bpm = int(bpm_value)
except (ValueError, TypeError):
pass
# Convert duration to float
duration = None
duration_value = metadata.get('duration')
if duration_value is not None and duration_value != 'N/A' and duration_value != '':
try:
duration = float(duration_value)
except (ValueError, TypeError):
pass
# Clean up N/A values
if keyscale == 'N/A':
keyscale = ''
if language == 'N/A':
language = ''
if timesignature == 'N/A':
timesignature = ''
return UnderstandResult(
caption=caption,
lyrics=lyrics,
bpm=bpm,
duration=duration,
keyscale=keyscale,
language=language,
timesignature=timesignature,
status_message=status,
success=True,
error=None,
)
except Exception as e:
logger.exception("Music understanding failed")
return UnderstandResult(
status_message=f"Error: {str(e)}",
success=False,
error=str(e),
)
@dataclass
class CreateSampleResult:
"""Result of creating a music sample from a natural language query.
This is used by the "Simple Mode" / "Inspiration Mode" feature where users
provide a natural language description and the LLM generates a complete
sample with caption, lyrics, and metadata.
Attributes:
# Metadata Fields
caption: Generated detailed music description/caption
lyrics: Generated lyrics (or "[Instrumental]" for instrumental music)
bpm: Beats per minute (None if not generated)
duration: Duration in seconds (None if not generated)
keyscale: Musical key (e.g., "C Major")
language: Vocal language code (e.g., "en", "zh")
timesignature: Time signature (e.g., "4")
instrumental: Whether this is an instrumental piece
# Status
status_message: Status message from sample creation
success: Whether sample creation completed successfully
error: Error message if sample creation failed
"""
# Metadata Fields
caption: str = ""
lyrics: str = ""
bpm: Optional[int] = None
duration: Optional[float] = None
keyscale: str = ""
language: str = ""
timesignature: str = ""
instrumental: bool = False
# Status
status_message: str = ""
success: bool = True
error: Optional[str] = None
def to_dict(self) -> Dict[str, Any]:
"""Convert result to dictionary for JSON serialization."""
return asdict(self)
def create_sample(
llm_handler,
query: str,
instrumental: bool = False,
vocal_language: Optional[str] = None,
temperature: float = 0.85,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
repetition_penalty: float = 1.0,
use_constrained_decoding: bool = True,
constrained_decoding_debug: bool = False,
) -> CreateSampleResult:
"""Create a music sample from a natural language query using the 5Hz Language Model.
This is the "Simple Mode" / "Inspiration Mode" feature that takes a user's natural
language description of music and generates a complete sample including:
- Detailed caption/description
- Lyrics (unless instrumental)
- Metadata (BPM, duration, key, language, time signature)
Note: cfg_scale and negative_prompt are not supported in create_sample mode.
Args:
llm_handler: Initialized LLM handler (LLMHandler instance)
query: User's natural language music description (e.g., "a soft Bengali love song")
instrumental: Whether to generate instrumental music (no vocals)
vocal_language: Allowed vocal language for constrained decoding (e.g., "en", "zh").
If provided, the model will be constrained to generate lyrics in this language.
If None or "unknown", no language constraint is applied.
temperature: Sampling temperature for generation (0.0-2.0). Higher = more creative.
top_k: Top-K sampling (None or 0 = disabled)
top_p: Top-P (nucleus) sampling (None or 1.0 = disabled)
repetition_penalty: Repetition penalty (1.0 = no penalty)
use_constrained_decoding: Whether to use FSM-based constrained decoding
constrained_decoding_debug: Whether to enable debug logging
Returns:
CreateSampleResult with generated sample fields and status
Example:
>>> result = create_sample(llm_handler, "a soft Bengali love song for a quiet evening", vocal_language="bn")
>>> if result.success:
... print(f"Caption: {result.caption}")
... print(f"Lyrics: {result.lyrics}")
... print(f"BPM: {result.bpm}")
"""
# Check if LLM is initialized
if not llm_handler.llm_initialized:
return CreateSampleResult(
status_message="5Hz LM not initialized. Please initialize it first.",
success=False,
error="LLM not initialized",
)
try:
# Call LLM to create sample
metadata, status = llm_handler.create_sample_from_query(
query=query,
instrumental=instrumental,
vocal_language=vocal_language,
temperature=temperature,
top_k=top_k,
top_p=top_p,
repetition_penalty=repetition_penalty,
use_constrained_decoding=use_constrained_decoding,
constrained_decoding_debug=constrained_decoding_debug,
)
# Check if LLM returned empty metadata (error case)
if not metadata:
return CreateSampleResult(
status_message=status or "Failed to create sample",
success=False,
error=status or "Empty metadata returned",
)
# Extract and convert fields
caption = metadata.get('caption', '')
lyrics = metadata.get('lyrics', '')
keyscale = metadata.get('keyscale', '')
language = metadata.get('language', metadata.get('vocal_language', ''))
timesignature = metadata.get('timesignature', '')
is_instrumental = metadata.get('instrumental', instrumental)
# Convert BPM to int
bpm = None
bpm_value = metadata.get('bpm')
if bpm_value is not None and bpm_value != 'N/A' and bpm_value != '':
try:
bpm = int(bpm_value)
except (ValueError, TypeError):
pass
# Convert duration to float
duration = None
duration_value = metadata.get('duration')
if duration_value is not None and duration_value != 'N/A' and duration_value != '':
try:
duration = float(duration_value)
except (ValueError, TypeError):
pass
# Clean up N/A values
if keyscale == 'N/A':
keyscale = ''
if language == 'N/A':
language = ''
if timesignature == 'N/A':
timesignature = ''
return CreateSampleResult(
caption=caption,
lyrics=lyrics,
bpm=bpm,
duration=duration,
keyscale=keyscale,
language=language,
timesignature=timesignature,
instrumental=is_instrumental,
status_message=status,
success=True,
error=None,
)
except Exception as e:
logger.exception("Sample creation failed")
return CreateSampleResult(
status_message=f"Error: {str(e)}",
success=False,
error=str(e),
)
@dataclass
class FormatSampleResult:
"""Result of formatting user-provided caption and lyrics.
This is used by the "Format" feature where users provide caption and lyrics,
and the LLM formats them into structured music metadata and an enhanced description.
Attributes:
# Metadata Fields
caption: Enhanced/formatted music description/caption
lyrics: Formatted lyrics (may be same as input or reformatted)
bpm: Beats per minute (None if not detected)
duration: Duration in seconds (None if not detected)
keyscale: Musical key (e.g., "C Major")
language: Vocal language code (e.g., "en", "zh")
timesignature: Time signature (e.g., "4")
# Status
status_message: Status message from formatting
success: Whether formatting completed successfully
error: Error message if formatting failed
"""
# Metadata Fields
caption: str = ""
lyrics: str = ""
bpm: Optional[int] = None
duration: Optional[float] = None
keyscale: str = ""
language: str = ""
timesignature: str = ""
# Status
status_message: str = ""
success: bool = True
error: Optional[str] = None
def to_dict(self) -> Dict[str, Any]:
"""Convert result to dictionary for JSON serialization."""
return asdict(self)
def format_sample(
llm_handler,
caption: str,
lyrics: str,
user_metadata: Optional[Dict[str, Any]] = None,
temperature: float = 0.85,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
repetition_penalty: float = 1.0,
use_constrained_decoding: bool = True,
constrained_decoding_debug: bool = False,
) -> FormatSampleResult:
"""Format user-provided caption and lyrics using the 5Hz Language Model.
This function takes user input (caption and lyrics) and generates structured
music metadata including an enhanced caption, BPM, duration, key, language,
and time signature.
If user_metadata is provided, those values will be used to constrain the
decoding, ensuring the output matches user-specified values.
Note: cfg_scale and negative_prompt are not supported in format mode.
Args:
llm_handler: Initialized LLM handler (LLMHandler instance)
caption: User's caption/description (e.g., "Latin pop, reggaeton")
lyrics: User's lyrics with structure tags
user_metadata: Optional dict with user-provided metadata to constrain decoding.
Supported keys: bpm, duration, keyscale, timesignature, language
temperature: Sampling temperature for generation (0.0-2.0). Higher = more creative.
top_k: Top-K sampling (None or 0 = disabled)
top_p: Top-P (nucleus) sampling (None or 1.0 = disabled)
repetition_penalty: Repetition penalty (1.0 = no penalty)
use_constrained_decoding: Whether to use FSM-based constrained decoding for metadata
constrained_decoding_debug: Whether to enable debug logging for constrained decoding
Returns:
FormatSampleResult with formatted metadata fields and status
Example:
>>> result = format_sample(llm_handler, "Latin pop, reggaeton", "[Verse 1]\\nHola mundo...")
>>> if result.success:
... print(f"Caption: {result.caption}")
... print(f"BPM: {result.bpm}")
... print(f"Lyrics: {result.lyrics}")
"""
# Check if LLM is initialized
if not llm_handler.llm_initialized:
return FormatSampleResult(
status_message="5Hz LM not initialized. Please initialize it first.",
success=False,
error="LLM not initialized",
)
try:
# Call LLM formatting
metadata, status = llm_handler.format_sample_from_input(
caption=caption,
lyrics=lyrics,
user_metadata=user_metadata,
temperature=temperature,
top_k=top_k,
top_p=top_p,
repetition_penalty=repetition_penalty,
use_constrained_decoding=use_constrained_decoding,
constrained_decoding_debug=constrained_decoding_debug,
)
# Check if LLM returned empty metadata (error case)
if not metadata:
return FormatSampleResult(
status_message=status or "Failed to format input",
success=False,
error=status or "Empty metadata returned",
)
# Extract and convert fields
result_caption = metadata.get('caption', '')
result_lyrics = metadata.get('lyrics', lyrics) # Fall back to input lyrics
keyscale = metadata.get('keyscale', '')
language = metadata.get('language', metadata.get('vocal_language', ''))
timesignature = metadata.get('timesignature', '')
# Convert BPM to int
bpm = None
bpm_value = metadata.get('bpm')
if bpm_value is not None and bpm_value != 'N/A' and bpm_value != '':
try:
bpm = int(bpm_value)
except (ValueError, TypeError):
pass
# Convert duration to float
duration = None
duration_value = metadata.get('duration')
if duration_value is not None and duration_value != 'N/A' and duration_value != '':
try:
duration = float(duration_value)
except (ValueError, TypeError):
pass
# Clean up N/A values
if keyscale == 'N/A':
keyscale = ''
if language == 'N/A':
language = ''
if timesignature == 'N/A':
timesignature = ''
return FormatSampleResult(
caption=result_caption,
lyrics=result_lyrics,
bpm=bpm,
duration=duration,
keyscale=keyscale,
language=language,
timesignature=timesignature,
status_message=status,
success=True,
error=None,
)
except Exception as e:
logger.exception("Format sample failed")
return FormatSampleResult(
status_message=f"Error: {str(e)}",
success=False,
error=str(e),
)