refact handler
Files changed:
- .gitignore +2 -0
- acestep/acestep_v15_pipeline.py +132 -7
- acestep/gradio_ui.py +161 -55
- acestep/handler.py +3 -2
- acestep/llm_inference.py +234 -12
- acestep/third_parts/nano-vllm/nanovllm/engine/model_runner.py +100 -9
- acestep/third_parts/nano-vllm/nanovllm/engine/sequence.py +3 -0
- acestep/third_parts/nano-vllm/nanovllm/layers/sampler.py +60 -1
- acestep/third_parts/nano-vllm/nanovllm/sampling_params.py +9 -0
.gitignore
CHANGED

@@ -213,3 +213,5 @@ tests/
 checkpoints/
 playground.ipynb
 .history/
+upload_checkpoints.sh
+checkpoints.7z
acestep/acestep_v15_pipeline.py
CHANGED

@@ -15,20 +15,33 @@ from .dataset_handler import DatasetHandler
 from .gradio_ui import create_gradio_interface
 
 
-def create_demo():
+def create_demo(init_params=None):
     """
     Create Gradio demo interface
 
+    Args:
+        init_params: Dictionary containing initialization parameters and state.
+            If None, service will not be pre-initialized.
+            Keys: 'pre_initialized' (bool), 'checkpoint', 'config_path', 'device',
+                'init_llm', 'lm_model_path', 'backend', 'use_flash_attention',
+                'offload_to_cpu', 'offload_dit_to_cpu', 'init_status',
+                'dit_handler', 'llm_handler' (initialized handlers if pre-initialized)
+
     Returns:
         Gradio Blocks instance
     """
-    # …
-    …
-    …
+    # Use pre-initialized handlers if available, otherwise create new ones
+    if init_params and init_params.get('pre_initialized') and 'dit_handler' in init_params:
+        dit_handler = init_params['dit_handler']
+        llm_handler = init_params['llm_handler']
+    else:
+        dit_handler = AceStepHandler()  # DiT handler
+        llm_handler = LLMHandler()  # LM handler
+
     dataset_handler = DatasetHandler()  # Dataset handler
 
-    # Create Gradio interface with all handlers
-    demo = create_gradio_interface(dit_handler, llm_handler, dataset_handler)
+    # Create Gradio interface with all handlers and initialization parameters
+    demo = create_gradio_interface(dit_handler, llm_handler, dataset_handler, init_params=init_params)
 
     return demo
 

@@ -42,12 +55,124 @@ def main():
     parser.add_argument("--share", action="store_true", help="Create a public link")
     parser.add_argument("--debug", action="store_true", help="Enable debug mode")
     parser.add_argument("--server-name", type=str, default="127.0.0.1", help="Server name (default: 127.0.0.1, use 0.0.0.0 for all interfaces)")
+
+    # Service initialization arguments
+    parser.add_argument("--init_service", type=lambda x: x.lower() in ['true', '1', 'yes'], default=False, help="Initialize service on startup (default: False)")
+    parser.add_argument("--checkpoint", type=str, default=None, help="Checkpoint file path (optional, for display purposes)")
+    parser.add_argument("--config_path", type=str, default=None, help="Main model path (e.g., 'acestep-v15-turbo')")
+    parser.add_argument("--device", type=str, default="auto", choices=["auto", "cuda", "cpu"], help="Processing device (default: auto)")
+    parser.add_argument("--init_llm", type=lambda x: x.lower() in ['true', '1', 'yes'], default=True, help="Initialize 5Hz LM (default: True)")
+    parser.add_argument("--lm_model_path", type=str, default=None, help="5Hz LM model path (e.g., 'acestep-5Hz-lm-0.6B')")
+    parser.add_argument("--backend", type=str, default="vllm", choices=["vllm", "pt"], help="5Hz LM backend (default: vllm)")
+    parser.add_argument("--use_flash_attention", type=lambda x: x.lower() in ['true', '1', 'yes'], default=None, help="Use flash attention (default: auto-detect)")
+    parser.add_argument("--offload_to_cpu", type=lambda x: x.lower() in ['true', '1', 'yes'], default=False, help="Offload models to CPU (default: False)")
+    parser.add_argument("--offload_dit_to_cpu", type=lambda x: x.lower() in ['true', '1', 'yes'], default=False, help="Offload DiT to CPU (default: False)")
+
     args = parser.parse_args()
 
     try:
+        init_params = None
+
+        # If init_service is True, perform initialization before creating UI
+        if args.init_service:
+            print("Initializing service from command line...")
+
+            # Create handler instances for initialization
+            dit_handler = AceStepHandler()
+            llm_handler = LLMHandler()
+
+            # Auto-select config_path if not provided
+            if args.config_path is None:
+                available_models = dit_handler.get_available_acestep_v15_models()
+                if available_models:
+                    args.config_path = "acestep-v15-turbo" if "acestep-v15-turbo" in available_models else available_models[0]
+                    print(f"Auto-selected config_path: {args.config_path}")
+                else:
+                    print("Error: No available models found. Please specify --config_path", file=sys.stderr)
+                    sys.exit(1)
+
+            # Get project root (same logic as in handler)
+            current_file = os.path.abspath(__file__)
+            project_root = os.path.dirname(os.path.dirname(current_file))
+
+            # Determine flash attention setting
+            use_flash_attention = args.use_flash_attention
+            if use_flash_attention is None:
+                use_flash_attention = dit_handler.is_flash_attention_available()
+
+            # Initialize DiT handler
+            print(f"Initializing DiT model: {args.config_path} on {args.device}...")
+            init_status, enable_generate = dit_handler.initialize_service(
+                project_root=project_root,
+                config_path=args.config_path,
+                device=args.device,
+                use_flash_attention=use_flash_attention,
+                compile_model=False,
+                offload_to_cpu=args.offload_to_cpu,
+                offload_dit_to_cpu=args.offload_dit_to_cpu
+            )
+
+            if not enable_generate:
+                print(f"Error initializing DiT model: {init_status}", file=sys.stderr)
+                sys.exit(1)
+
+            print(f"DiT model initialized successfully")
+
+            # Initialize LM handler if requested
+            lm_status = ""
+            if args.init_llm:
+                if args.lm_model_path is None:
+                    # Try to get default LM model
+                    available_lm_models = llm_handler.get_available_5hz_lm_models()
+                    if available_lm_models:
+                        args.lm_model_path = available_lm_models[0]
+                        print(f"Using default LM model: {args.lm_model_path}")
+                    else:
+                        print("Warning: No LM models available, skipping LM initialization", file=sys.stderr)
+                        args.init_llm = False
+
+                if args.init_llm and args.lm_model_path:
+                    checkpoint_dir = os.path.join(project_root, "checkpoints")
+                    print(f"Initializing 5Hz LM: {args.lm_model_path} on {args.device}...")
+                    lm_status, lm_success = llm_handler.initialize(
+                        checkpoint_dir=checkpoint_dir,
+                        lm_model_path=args.lm_model_path,
+                        backend=args.backend,
+                        device=args.device,
+                        offload_to_cpu=args.offload_to_cpu,
+                        dtype=dit_handler.dtype
+                    )
+
+                    if lm_success:
+                        print(f"5Hz LM initialized successfully")
+                        init_status += f"\n{lm_status}"
+                    else:
+                        print(f"Warning: 5Hz LM initialization failed: {lm_status}", file=sys.stderr)
+                        init_status += f"\n{lm_status}"
+
+            # Prepare initialization parameters for UI
+            init_params = {
+                'pre_initialized': True,
+                'checkpoint': args.checkpoint,
+                'config_path': args.config_path,
+                'device': args.device,
+                'init_llm': args.init_llm,
+                'lm_model_path': args.lm_model_path,
+                'backend': args.backend,
+                'use_flash_attention': use_flash_attention,
+                'offload_to_cpu': args.offload_to_cpu,
+                'offload_dit_to_cpu': args.offload_dit_to_cpu,
+                'init_status': init_status,
+                'enable_generate': enable_generate,
+                'dit_handler': dit_handler,
+                'llm_handler': llm_handler
+            }
+
+            print("Service initialization completed successfully!")
+
         # Create and launch demo
         print("Creating Gradio interface...")
-        demo = create_demo()
+        demo = create_demo(init_params=init_params)
         print(f"Launching server on {args.server_name}:{args.port}...")
         demo.launch(
             server_name=args.server_name,
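Note on the new CLI flags above: they parse booleans with a small lambda (any of "true", "1", "yes", case-insensitive, maps to True) instead of argparse's store_true, so explicit values such as "--init_llm false" work as expected. A minimal, self-contained sketch of the same idiom (the parser below is illustrative, not the project's actual CLI):

import argparse

def str2bool(value: str) -> bool:
    # Mirrors the lambda used in the diff: "true"/"1"/"yes" (any case) -> True, everything else -> False
    return value.lower() in ['true', '1', 'yes']

parser = argparse.ArgumentParser(description="Illustrative parser, not the project's CLI")
parser.add_argument("--init_llm", type=str2bool, default=True,
                    help="Initialize 5Hz LM (default: True)")
args = parser.parse_args(["--init_llm", "false"])
print(args.init_llm)  # False

The trade-off of this pattern is that any other string (including typos) silently maps to False, which is presumably why the help text spells out the accepted values.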
acestep/gradio_ui.py
CHANGED

@@ -7,7 +7,7 @@ import gradio as gr
 from typing import Callable, Optional
 
 
-def create_gradio_interface(dit_handler, llm_handler, dataset_handler) -> gr.Blocks:
+def create_gradio_interface(dit_handler, llm_handler, dataset_handler, init_params=None) -> gr.Blocks:
     """
     Create Gradio interface
 

@@ -15,6 +15,8 @@ def create_gradio_interface(dit_handler, llm_handler, dataset_handler) -> gr.Blo
         dit_handler: DiT handler instance
         llm_handler: LM handler instance
         dataset_handler: Dataset handler instance
+        init_params: Dictionary containing initialization parameters and state.
+            If None, service will not be pre-initialized.
 
     Returns:
         Gradio Blocks instance

@@ -47,8 +49,8 @@ def create_gradio_interface(dit_handler, llm_handler, dataset_handler) -> gr.Blo
     # Dataset Explorer Section
     dataset_section = create_dataset_section(dataset_handler)
 
-    # Generation Section
-    generation_section = create_generation_section(dit_handler, llm_handler)
+    # Generation Section (pass init_params to support pre-initialization)
+    generation_section = create_generation_section(dit_handler, llm_handler, init_params=init_params)
 
     # Results Section
     results_section = create_results_section(dit_handler)

@@ -156,20 +158,33 @@ def create_dataset_section(dataset_handler) -> dict:
     }
 
 
-def create_generation_section(dit_handler, llm_handler) -> dict:
-    """Create generation section
+def create_generation_section(dit_handler, llm_handler, init_params=None) -> dict:
+    """Create generation section
+
+    Args:
+        dit_handler: DiT handler instance
+        llm_handler: LM handler instance
+        init_params: Dictionary containing initialization parameters and state.
+            If None, service will not be pre-initialized.
+    """
+    # Check if service is pre-initialized
+    service_pre_initialized = init_params is not None and init_params.get('pre_initialized', False)
+
     with gr.Group():
         gr.HTML('<div class="section-header"><h3>🎼 ACE-Step V1.5 Demo </h3></div>')
 
-        # Service Configuration
-        …
+        # Service Configuration - collapse if pre-initialized
+        accordion_open = not service_pre_initialized
+        with gr.Accordion("🔧 Service Configuration", open=accordion_open) as service_config_accordion:
             # Dropdown options section - all dropdowns grouped together
             with gr.Row(equal_height=True):
                 with gr.Column(scale=4):
+                    # Set checkpoint value from init_params if pre-initialized
+                    checkpoint_value = init_params.get('checkpoint') if service_pre_initialized else None
                     checkpoint_dropdown = gr.Dropdown(
                         label="Checkpoint File",
                         choices=dit_handler.get_available_checkpoints(),
-                        value=…
+                        value=checkpoint_value,
                         info="Select a trained model checkpoint file (full path or filename)"
                     )
                 with gr.Column(scale=1, min_width=90):

@@ -180,15 +195,19 @@ def create_generation_section(dit_handler, llm_handler) -> dict:
                     available_models = dit_handler.get_available_acestep_v15_models()
                     default_model = "acestep-v15-turbo" if "acestep-v15-turbo" in available_models else (available_models[0] if available_models else None)
 
+                    # Set config_path value from init_params if pre-initialized
+                    config_path_value = init_params.get('config_path', default_model) if service_pre_initialized else default_model
                     config_path = gr.Dropdown(
                         label="Main Model Path",
                         choices=available_models,
-                        value=…
+                        value=config_path_value,
                         info="Select the model configuration directory (auto-scanned from checkpoints)"
                     )
+                    # Set device value from init_params if pre-initialized
+                    device_value = init_params.get('device', 'auto') if service_pre_initialized else 'auto'
                     device = gr.Dropdown(
                         choices=["auto", "cuda", "cpu"],
-                        value=…
+                        value=device_value,
                         label="Device",
                         info="Processing device (auto-detect recommended)"
                     )

@@ -198,47 +217,61 @@ def create_generation_section(dit_handler, llm_handler) -> dict:
                     available_lm_models = llm_handler.get_available_5hz_lm_models()
                     default_lm_model = "acestep-5Hz-lm-0.6B" if "acestep-5Hz-lm-0.6B" in available_lm_models else (available_lm_models[0] if available_lm_models else None)
 
+                    # Set lm_model_path value from init_params if pre-initialized
+                    lm_model_path_value = init_params.get('lm_model_path', default_lm_model) if service_pre_initialized else default_lm_model
                     lm_model_path = gr.Dropdown(
                         label="5Hz LM Model Path",
                         choices=available_lm_models,
-                        value=…
+                        value=lm_model_path_value,
                         info="Select the 5Hz LM model checkpoint (auto-scanned from checkpoints)"
                     )
+                    # Set backend value from init_params if pre-initialized
+                    backend_value = init_params.get('backend', 'vllm') if service_pre_initialized else 'vllm'
                     backend_dropdown = gr.Dropdown(
                         choices=["vllm", "pt"],
-                        value=…
+                        value=backend_value,
                         label="5Hz LM Backend",
                         info="Select backend for 5Hz LM: vllm (faster) or pt (PyTorch, more compatible)"
                     )
 
             # Checkbox options section - all checkboxes grouped together
             with gr.Row():
+                # Set init_llm value from init_params if pre-initialized
+                init_llm_value = init_params.get('init_llm', True) if service_pre_initialized else True
                 init_llm_checkbox = gr.Checkbox(
                     label="Initialize 5Hz LM",
-                    value=…
+                    value=init_llm_value,
                     info="Check to initialize 5Hz LM during service initialization",
                 )
                 # Auto-detect flash attention availability
                 flash_attn_available = dit_handler.is_flash_attention_available()
+                # Set use_flash_attention value from init_params if pre-initialized
+                use_flash_attention_value = init_params.get('use_flash_attention', flash_attn_available) if service_pre_initialized else flash_attn_available
                 use_flash_attention_checkbox = gr.Checkbox(
                     label="Use Flash Attention",
-                    value=…
+                    value=use_flash_attention_value,
                     interactive=flash_attn_available,
                     info="Enable flash attention for faster inference (requires flash_attn package)" if flash_attn_available else "Flash attention not available (flash_attn package not installed)"
                 )
+                # Set offload_to_cpu value from init_params if pre-initialized
+                offload_to_cpu_value = init_params.get('offload_to_cpu', False) if service_pre_initialized else False
                 offload_to_cpu_checkbox = gr.Checkbox(
                     label="Offload to CPU",
-                    value=…
+                    value=offload_to_cpu_value,
                     info="Offload models to CPU when not in use to save GPU memory"
                 )
+                # Set offload_dit_to_cpu value from init_params if pre-initialized
+                offload_dit_to_cpu_value = init_params.get('offload_dit_to_cpu', False) if service_pre_initialized else False
                 offload_dit_to_cpu_checkbox = gr.Checkbox(
                     label="Offload DiT to CPU",
-                    value=…
+                    value=offload_dit_to_cpu_value,
                     info="Offload DiT to CPU (needs Offload to CPU)"
                 )
 
             init_btn = gr.Button("Initialize Service", variant="primary", size="lg")
-            init_status…
+            # Set init_status value from init_params if pre-initialized
+            init_status_value = init_params.get('init_status', '') if service_pre_initialized else ''
+            init_status = gr.Textbox(label="Status", interactive=False, lines=3, value=init_status_value)
 
         # Inputs
         with gr.Row():

@@ -328,7 +361,7 @@ def create_generation_section(dit_handler, llm_handler) -> dict:
                     label="Temperature",
                     minimum=0.0,
                     maximum=2.0,
-                    value=0.…
+                    value=0.85,
                     step=0.1,
                     scale=1,
                     info="Temperature for 5Hz LM sampling (higher = more random, lower = more deterministic)"

@@ -337,18 +370,48 @@ def create_generation_section(dit_handler, llm_handler) -> dict:
                     label="CFG Scale",
                     minimum=1.0,
                     maximum=3.0,
-                    value=…
+                    value=2.0,
                     step=0.1,
                     scale=1,
                     info="Classifier-Free Guidance scale for 5Hz LM (1.0 = no CFG, higher = stronger guidance)"
                 )
 
+            with gr.Row():
+                lm_top_k = gr.Slider(
+                    label="Top-K",
+                    minimum=0,
+                    maximum=100,
+                    value=0,
+                    step=1,
+                    scale=1,
+                    info="Top-K sampling: consider only top K tokens (0 = disabled)"
+                )
+                lm_top_p = gr.Slider(
+                    label="Top-P",
+                    minimum=0.0,
+                    maximum=1.0,
+                    value=0.9,
+                    step=0.01,
+                    scale=1,
+                    info="Top-P (nucleus) sampling: cumulative probability threshold (1.0 = disabled)"
+                )
+                lm_repetition_penalty = gr.Slider(
+                    label="Repetition Penalty",
+                    minimum=0.8,
+                    maximum=1.2,
+                    value=1.0,
+                    step=0.01,
+                    scale=1,
+                    info="Repetition penalty: >1.0 reduces repetition, <1.0 increases it (1.0 = no penalty). For audio generation, use 1.0 or very small values (1.01-1.05) as audio tokens naturally repeat.",
+                    visible=False,
+                )
+
             # Negative prompt for CFG (only visible when LM initialized and cfg_scale > 1)
             lm_negative_prompt = gr.Textbox(
                 label="Negative Prompt",
                 value="NO USER INPUT",
                 placeholder="Enter negative prompt for CFG (default: NO USER INPUT)",
-                visible=…
+                visible=True,
                 info="Negative prompt used for Classifier-Free Guidance when CFG Scale > 1.0",
                 lines=2
             )

@@ -377,7 +440,7 @@ def create_generation_section(dit_handler, llm_handler) -> dict:
                 step=0.01,
                 label="Audio Cover Strength",
                 info="Control how many denoising steps use cover mode",
-                visible=…
+                visible=True
             )
 
             # Music Caption

@@ -514,7 +577,9 @@ def create_generation_section(dit_handler, llm_handler) -> dict:
                 interactive=False
             )
 
-        generate_btn…
+        # Set generate_btn to interactive if service is pre-initialized
+        generate_btn_interactive = init_params.get('enable_generate', False) if service_pre_initialized else False
+        generate_btn = gr.Button("🎵 Generate Music", variant="primary", size="lg", interactive=generate_btn_interactive)
 
     return {
         "checkpoint_dropdown": checkpoint_dropdown,

@@ -542,6 +607,9 @@ def create_generation_section(dit_handler, llm_handler) -> dict:
         "use_5hz_lm_btn": use_5hz_lm_btn,
         "lm_temperature": lm_temperature,
         "lm_cfg_scale": lm_cfg_scale,
+        "lm_top_k": lm_top_k,
+        "lm_top_p": lm_top_p,
+        "lm_repetition_penalty": lm_repetition_penalty,
         "lm_negative_prompt": lm_negative_prompt,
         "repainting_group": repainting_group,
         "repainting_start": repainting_start,

@@ -733,6 +801,47 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
 
         return status, gr.update(interactive=enable)
 
+    # Update negative prompt visibility based on "Initialize 5Hz LM" checkbox
+    def update_negative_prompt_visibility(init_llm_checked):
+        """Update negative prompt visibility: show if Initialize 5Hz LM checkbox is checked"""
+        return gr.update(visible=init_llm_checked)
+
+    # Update audio_cover_strength visibility and label based on task type and LM initialization
+    def update_audio_cover_strength_visibility(task_type_value, init_llm_checked):
+        """Update audio_cover_strength visibility and label"""
+        # Show if task is cover OR if LM is initialized
+        is_visible = (task_type_value == "cover") or init_llm_checked
+        # Change label based on context
+        if init_llm_checked and task_type_value != "cover":
+            label = "LM codes strength"
+            info = "Control how many denoising steps use LM-generated codes"
+        else:
+            label = "Audio Cover Strength"
+            info = "Control how many denoising steps use cover mode"
+
+        return gr.update(visible=is_visible, label=label, info=info)
+
+    # Update visibility when init_llm_checkbox changes
+    generation_section["init_llm_checkbox"].change(
+        fn=update_negative_prompt_visibility,
+        inputs=[generation_section["init_llm_checkbox"]],
+        outputs=[generation_section["lm_negative_prompt"]]
+    )
+
+    # Update audio_cover_strength visibility and label when init_llm_checkbox changes
+    generation_section["init_llm_checkbox"].change(
+        fn=update_audio_cover_strength_visibility,
+        inputs=[generation_section["task_type"], generation_section["init_llm_checkbox"]],
+        outputs=[generation_section["audio_cover_strength"]]
+    )
+
+    # Also update audio_cover_strength when task_type changes (to handle label changes)
+    generation_section["task_type"].change(
+        fn=update_audio_cover_strength_visibility,
+        inputs=[generation_section["task_type"], generation_section["init_llm_checkbox"]],
+        outputs=[generation_section["audio_cover_strength"]]
+    )
+
     generation_section["init_btn"].click(
         fn=init_service_wrapper,
         inputs=[

@@ -749,30 +858,6 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
         outputs=[generation_section["init_status"], generation_section["generate_btn"]]
     )
 
-    # Update negative prompt visibility based on LM initialization and CFG scale
-    def update_negative_prompt_visibility(init_status, cfg_scale):
-        """Update negative prompt visibility: show only if LM initialized and cfg_scale > 1"""
-        # Check if LM is initialized by looking for "5Hz LM backend:" in status
-        lm_initialized = init_status is not None and "5Hz LM backend:" in str(init_status)
-        # Check if cfg_scale > 1
-        cfg_enabled = cfg_scale is not None and float(cfg_scale) > 1.0
-        # Show only if both conditions are met
-        return gr.update(visible=lm_initialized and cfg_enabled)
-
-    # Update visibility when init_status changes
-    generation_section["init_status"].change(
-        fn=update_negative_prompt_visibility,
-        inputs=[generation_section["init_status"], generation_section["lm_cfg_scale"]],
-        outputs=[generation_section["lm_negative_prompt"]]
-    )
-
-    # Update visibility when cfg_scale changes
-    generation_section["lm_cfg_scale"].change(
-        fn=update_negative_prompt_visibility,
-        inputs=[generation_section["init_status"], generation_section["lm_cfg_scale"]],
-        outputs=[generation_section["lm_negative_prompt"]]
-    )
-
     # Generation with progress bar
     def generate_with_progress(
         captions, lyrics, bpm, key_scale, time_signature, vocal_language,

@@ -845,9 +930,16 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
     )
 
     # 5Hz LM generation (simplified version, can be extended as needed)
-    def generate_lm_hints_wrapper(caption, lyrics, temperature, cfg_scale, negative_prompt):
+    def generate_lm_hints_wrapper(caption, lyrics, temperature, cfg_scale, top_k, top_p, repetition_penalty, negative_prompt):
         """Wrapper for 5Hz LM generation"""
-        …
+        # Convert top_k: 0 means None (disabled)
+        top_k_value = None if top_k == 0 else int(top_k)
+        # Convert top_p: 1.0 means None (disabled)
+        top_p_value = None if top_p >= 1.0 else top_p
+        metadata, audio_codes, status = llm_handler.generate_with_5hz_lm(
+            caption, lyrics, temperature, cfg_scale, negative_prompt,
+            top_k_value, top_p_value, repetition_penalty
+        )
 
         # Extract metadata values and map to UI fields
         # Handle bpm

@@ -886,6 +978,9 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
             generation_section["lyrics"],
             generation_section["lm_temperature"],
             generation_section["lm_cfg_scale"],
+            generation_section["lm_top_k"],
+            generation_section["lm_top_p"],
+            generation_section["lm_repetition_penalty"],
             generation_section["lm_negative_prompt"]
         ],
         outputs=[

@@ -902,7 +997,8 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
         task_type_value: str,
         track_name_value: Optional[str],
         complete_track_classes_value: list,
-        audio_codes_content: str = ""
+        audio_codes_content: str = "",
+        init_llm_checked: bool = False
     ) -> tuple:
         """Update instruction and UI visibility based on task type."""
         instruction = dit_handler.generate_instruction(

@@ -915,8 +1011,15 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
         track_name_visible = task_type_value in ["lego", "extract"]
         # Show complete_track_classes for complete
        complete_visible = task_type_value == "complete"
-        # Show audio_cover_strength for cover
-        audio_cover_strength_visible = task_type_value == "cover"
+        # Show audio_cover_strength for cover OR when LM is initialized
+        audio_cover_strength_visible = (task_type_value == "cover") or init_llm_checked
+        # Determine label and info based on context
+        if init_llm_checked and task_type_value != "cover":
+            audio_cover_strength_label = "LM codes strength"
+            audio_cover_strength_info = "Control how many denoising steps use LM-generated codes"
+        else:
+            audio_cover_strength_label = "Audio Cover Strength"
+            audio_cover_strength_info = "Control how many denoising steps use cover mode"
         # Show audio_code_string for cover
         audio_code_visible = task_type_value == "cover"
         # Show repainting controls for repaint and lego

@@ -932,7 +1035,7 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
             instruction,  # instruction_display_gen
             gr.update(visible=track_name_visible),  # track_name
             gr.update(visible=complete_visible),  # complete_track_classes
-            gr.update(visible=audio_cover_strength_visible),  # audio_cover_strength
+            gr.update(visible=audio_cover_strength_visible, label=audio_cover_strength_label, info=audio_cover_strength_info),  # audio_cover_strength
             gr.update(visible=repainting_visible),  # repainting_group
             gr.update(visible=audio_code_visible),  # audio_code_string
             gr.update(visible=use_5hz_lm_visible),  # use_5hz_lm_row

@@ -946,7 +1049,8 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
             generation_section["task_type"],
             generation_section["track_name"],
             generation_section["complete_track_classes"],
-            generation_section["text2music_audio_code_string"]
+            generation_section["text2music_audio_code_string"],
+            generation_section["init_llm_checkbox"]
         ],
         outputs=[
             generation_section["instruction_display_gen"],

@@ -967,7 +1071,8 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
             generation_section["task_type"],
             generation_section["track_name"],
             generation_section["complete_track_classes"],
-            generation_section["text2music_audio_code_string"]
+            generation_section["text2music_audio_code_string"],
+            generation_section["init_llm_checkbox"]
         ],
         outputs=[
             generation_section["instruction_display_gen"],

@@ -988,7 +1093,8 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
             generation_section["task_type"],
             generation_section["track_name"],
             generation_section["complete_track_classes"],
-            generation_section["text2music_audio_code_string"]
+            generation_section["text2music_audio_code_string"],
+            generation_section["init_llm_checkbox"]
         ],
         outputs=[
             generation_section["instruction_display_gen"],
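Note on the new sampling controls above: the UI exposes Top-K and Top-P as sliders, and generate_lm_hints_wrapper converts the slider sentinels (0 for Top-K, 1.0 for Top-P) to None before calling generate_with_5hz_lm, so a "disabled" setting never reaches the backends as a real constraint. A small standalone sketch of that conversion (the helper name below is illustrative, not part of the codebase):

from typing import Optional, Tuple

def convert_sampling_sentinels(top_k: int, top_p: float) -> Tuple[Optional[int], Optional[float]]:
    # Mirrors the wrapper in the diff: slider value 0 disables top-k, 1.0 disables top-p.
    top_k_value = None if top_k == 0 else int(top_k)
    top_p_value = None if top_p >= 1.0 else top_p
    return top_k_value, top_p_value

print(convert_sampling_sentinels(0, 0.9))   # (None, 0.9)
print(convert_sampling_sentinels(40, 1.0))  # (40, None)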
acestep/handler.py
CHANGED

@@ -1362,7 +1362,7 @@ class AceStepHandler:
 
         padded_non_cover_text_input_ids = None
         padded_non_cover_text_attention_masks = None
-        if audio_cover_strength < 1.0
+        if audio_cover_strength < 1.0:
             non_cover_text_input_ids = []
             non_cover_text_attention_masks = []
             for i in range(batch_size):

@@ -1381,8 +1381,9 @@ class AceStepHandler:
                     return_tensors="pt",
                 )
                 text_token_ids = text_inputs_dict.input_ids[0]
+                non_cover_text_attention_mask = text_inputs_dict.attention_mask[0].bool()
                 non_cover_text_input_ids.append(text_token_ids)
-                non_cover_text_attention_masks.append(…
+                non_cover_text_attention_masks.append(non_cover_text_attention_mask)
 
             padded_non_cover_text_input_ids = torch.stack([
                 torch.nn.functional.pad(
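Note on the fix above: the handler now keeps the tokenizer's attention mask for each non-cover text prompt and appends it alongside the token ids, so both can later be right-padded to a common length and stacked into a batch. A minimal sketch of that pad-and-stack pattern, using hypothetical toy sequences rather than the handler's real tokenizer output:

import torch
import torch.nn.functional as F

# Hypothetical variable-length token sequences standing in for the per-prompt tokenizer outputs.
seqs = [torch.tensor([101, 7592, 102]), torch.tensor([101, 102])]
masks = [torch.ones(len(s), dtype=torch.long) for s in seqs]

max_len = max(s.shape[0] for s in seqs)

# Right-pad every sequence and mask to the common length, then stack into a batch,
# mirroring the torch.stack([torch.nn.functional.pad(...) ...]) pattern in the handler.
padded_ids = torch.stack([F.pad(s, (0, max_len - s.shape[0]), value=0) for s in seqs])
padded_masks = torch.stack([F.pad(m, (0, max_len - m.shape[0]), value=0) for m in masks]).bool()

print(padded_ids.shape, padded_masks.shape)  # torch.Size([2, 3]) torch.Size([2, 3])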
acestep/llm_inference.py
CHANGED
|
@@ -11,8 +11,18 @@ from contextlib import contextmanager
|
|
| 11 |
import torch
|
| 12 |
from tqdm import tqdm
|
| 13 |
from loguru import logger
|
| 14 |
-
from transformers import AutoTokenizer, AutoModelForCausalLM
|
| 15 |
from transformers.generation.streamers import BaseStreamer
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
|
| 17 |
|
| 18 |
class LLMHandler:
|
|
@@ -209,7 +219,17 @@ class LLMHandler:
|
|
| 209 |
error_msg = f"❌ Error initializing 5Hz LM: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
|
| 210 |
return error_msg
|
| 211 |
|
| 212 |
-
def generate_with_5hz_lm_vllm(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 213 |
"""Generate metadata and audio codes using 5Hz LM with vllm backend"""
|
| 214 |
try:
|
| 215 |
from nanovllm import SamplingParams
|
|
@@ -226,7 +246,14 @@ class LLMHandler:
|
|
| 226 |
)
|
| 227 |
logger.debug(f"[debug] formatted_prompt: {formatted_prompt}")
|
| 228 |
|
| 229 |
-
sampling_params = SamplingParams(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 230 |
# Use CFG if cfg_scale > 1.0
|
| 231 |
if cfg_scale > 1.0:
|
| 232 |
# Build unconditional prompt (user input replaced with "NO USER INPUT")
|
|
@@ -266,7 +293,17 @@ class LLMHandler:
|
|
| 266 |
error_msg = f"❌ Error generating with 5Hz LM: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
|
| 267 |
return {}, "", error_msg
|
| 268 |
|
| 269 |
-
def generate_with_5hz_lm_pt(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 270 |
"""Generate metadata and audio codes using 5Hz LM with PyTorch backend"""
|
| 271 |
try:
|
| 272 |
prompt = f"# Caption\n{caption}\n\n# Lyric\n{lyrics}\n"
|
|
@@ -295,7 +332,7 @@ class LLMHandler:
|
|
| 295 |
# Get max_new_tokens from model config or use a default
|
| 296 |
max_new_tokens = getattr(self.llm.config, 'max_new_tokens', 4096)
|
| 297 |
if hasattr(self, 'max_model_len'):
|
| 298 |
-
max_new_tokens = min(max_new_tokens, self.max_model_len)
|
| 299 |
|
| 300 |
# Define custom streamer for tqdm
|
| 301 |
class TqdmTokenStreamer(BaseStreamer):
|
|
@@ -315,15 +352,78 @@ class LLMHandler:
|
|
| 315 |
|
| 316 |
streamer = TqdmTokenStreamer(total=max_new_tokens)
|
| 317 |
|
| 318 |
-
|
| 319 |
-
|
| 320 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 321 |
max_new_tokens=max_new_tokens,
|
| 322 |
temperature=temperature,
|
| 323 |
-
|
|
|
|
| 324 |
pad_token_id=self.llm_tokenizer.pad_token_id or self.llm_tokenizer.eos_token_id,
|
| 325 |
streamer=streamer,
|
| 326 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 327 |
|
| 328 |
# Decode the generated tokens
|
| 329 |
# Only decode the newly generated tokens (skip the input prompt)
|
|
@@ -338,7 +438,17 @@ class LLMHandler:
|
|
| 338 |
error_msg = f"❌ Error generating with 5Hz LM: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
|
| 339 |
return {}, "", error_msg
|
| 340 |
|
| 341 |
-
def generate_with_5hz_lm(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 342 |
"""Generate metadata and audio codes using 5Hz LM"""
|
| 343 |
# Check if 5Hz LM is initialized
|
| 344 |
if not hasattr(self, 'llm_initialized') or not self.llm_initialized:
|
|
@@ -355,9 +465,15 @@ class LLMHandler:
|
|
| 355 |
return {}, "", "❌ 5Hz LM backend not set. Please initialize it first."
|
| 356 |
|
| 357 |
if self.llm_backend == "vllm":
|
| 358 |
-
return self.generate_with_5hz_lm_vllm(
|
|
|
|
|
|
|
|
|
|
| 359 |
else:
|
| 360 |
-
return self.generate_with_5hz_lm_pt(
|
|
|
|
|
|
|
|
|
|
| 361 |
|
| 362 |
def parse_lm_output(self, output_text: str) -> Tuple[Dict[str, Any], str]:
|
| 363 |
"""
|
|
@@ -440,6 +556,112 @@ class LLMHandler:
|
|
| 440 |
|
| 441 |
return metadata, audio_codes
|
| 442 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 443 |
@contextmanager
|
| 444 |
def _load_model_context(self):
|
| 445 |
"""
|
|
|
|
| 11 |
import torch
|
| 12 |
from tqdm import tqdm
|
| 13 |
from loguru import logger
|
| 14 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM, ClassifierFreeGuidanceLogitsProcessor
|
| 15 |
from transformers.generation.streamers import BaseStreamer
|
| 16 |
+
from transformers.generation.logits_process import (
|
| 17 |
+
LogitsProcessorList,
|
| 18 |
+
LogitsProcessor,
|
| 19 |
+
TopKLogitsWarper,
|
| 20 |
+
TopPLogitsWarper,
|
| 21 |
+
RepetitionPenaltyLogitsProcessor,
|
| 22 |
+
TemperatureLogitsWarper,
|
| 23 |
+
)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
|
| 27 |
|
| 28 |
class LLMHandler:
|
|
|
|
| 219 |
error_msg = f"❌ Error initializing 5Hz LM: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
|
| 220 |
return error_msg
|
| 221 |
|
| 222 |
+
def generate_with_5hz_lm_vllm(
|
| 223 |
+
self,
|
| 224 |
+
caption: str,
|
| 225 |
+
lyrics: str,
|
| 226 |
+
temperature: float = 0.6,
|
| 227 |
+
cfg_scale: float = 1.0,
|
| 228 |
+
negative_prompt: str = "NO USER INPUT",
|
| 229 |
+
top_k: Optional[int] = None,
|
| 230 |
+
top_p: Optional[float] = None,
|
| 231 |
+
repetition_penalty: float = 1.0,
|
| 232 |
+
) -> Tuple[Dict[str, Any], str, str]:
|
| 233 |
"""Generate metadata and audio codes using 5Hz LM with vllm backend"""
|
| 234 |
try:
|
| 235 |
from nanovllm import SamplingParams
|
|
|
|
| 246 |
)
|
| 247 |
logger.debug(f"[debug] formatted_prompt: {formatted_prompt}")
|
| 248 |
|
| 249 |
+
sampling_params = SamplingParams(
|
| 250 |
+
max_tokens=self.max_model_len-64,
|
| 251 |
+
temperature=temperature,
|
| 252 |
+
cfg_scale=cfg_scale,
|
| 253 |
+
top_k=top_k,
|
| 254 |
+
top_p=top_p,
|
| 255 |
+
repetition_penalty=repetition_penalty,
|
| 256 |
+
)
|
| 257 |
# Use CFG if cfg_scale > 1.0
|
| 258 |
if cfg_scale > 1.0:
|
| 259 |
# Build unconditional prompt (user input replaced with "NO USER INPUT")
|
|
|
|
| 293 |
error_msg = f"❌ Error generating with 5Hz LM: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
|
| 294 |
return {}, "", error_msg
|
| 295 |
|
| 296 |
+
def generate_with_5hz_lm_pt(
|
| 297 |
+
self,
|
| 298 |
+
caption: str,
|
| 299 |
+
lyrics: str,
|
| 300 |
+
temperature: float = 0.6,
|
| 301 |
+
cfg_scale: float = 1.0,
|
| 302 |
+
negative_prompt: str = "NO USER INPUT",
|
| 303 |
+
top_k: Optional[int] = None,
|
| 304 |
+
top_p: Optional[float] = None,
|
| 305 |
+
repetition_penalty: float = 1.0,
|
| 306 |
+
) -> Tuple[Dict[str, Any], str, str]:
|
| 307 |
"""Generate metadata and audio codes using 5Hz LM with PyTorch backend"""
|
| 308 |
try:
|
| 309 |
prompt = f"# Caption\n{caption}\n\n# Lyric\n{lyrics}\n"
|
|
|
|
| 332 |
# Get max_new_tokens from model config or use a default
|
| 333 |
max_new_tokens = getattr(self.llm.config, 'max_new_tokens', 4096)
|
| 334 |
if hasattr(self, 'max_model_len'):
|
| 335 |
+
max_new_tokens = min(max_new_tokens, self.max_model_len - 64)
|
| 336 |
|
| 337 |
# Define custom streamer for tqdm
|
| 338 |
class TqdmTokenStreamer(BaseStreamer):
|
|
|
|
| 352 |
|
| 353 |
streamer = TqdmTokenStreamer(total=max_new_tokens)
|
| 354 |
|
| 355 |
+
# Build logits processor list
|
| 356 |
+
logits_processor = LogitsProcessorList()
|
| 357 |
+
|
| 358 |
+
# Add repetition penalty if needed
|
| 359 |
+
if repetition_penalty != 1.0:
|
| 360 |
+
logits_processor.append(RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty))
|
| 361 |
+
|
| 362 |
+
# Add temperature warper if needed (temperature is handled separately in generate, but we can also use warper)
|
| 363 |
+
# Note: temperature is passed directly to generate(), but we can use TemperatureLogitsWarper for consistency
|
| 364 |
+
if temperature != 1.0:
|
| 365 |
+
logits_processor.append(TemperatureLogitsWarper(temperature=temperature))
|
| 366 |
+
|
| 367 |
+
# Add top-k warper if specified
|
| 368 |
+
if top_k is not None and top_k > 0:
|
| 369 |
+
logits_processor.append(TopKLogitsWarper(top_k=top_k))
|
| 370 |
+
|
| 371 |
+
# Add top-p warper if specified
|
| 372 |
+
if top_p is not None and top_p > 0.0 and top_p < 1.0:
|
| 373 |
+
logits_processor.append(TopPLogitsWarper(top_p=top_p))
|
| 374 |
+
|
| 375 |
+
# Handle CFG if cfg_scale > 1.0
|
| 376 |
+
if cfg_scale > 1.0:
|
| 377 |
+
# Build unconditional prompt
|
| 378 |
+
formatted_unconditional_prompt = self.llm_tokenizer.apply_chat_template(
|
| 379 |
+
[
|
| 380 |
+
{"role": "system", "content": "# Instruction\nGenerate audio semantic tokens based on the given conditions:\n\n"},
|
| 381 |
+
{"role": "user", "content": negative_prompt}
|
| 382 |
+
],
|
| 383 |
+
tokenize=False,
|
| 384 |
+
add_generation_prompt=True,
|
| 385 |
+
)
|
| 386 |
+
|
| 387 |
+
# Tokenize unconditional prompt
|
| 388 |
+
uncond_inputs = self.llm_tokenizer(
|
| 389 |
+
formatted_unconditional_prompt,
|
| 390 |
+
return_tensors="pt",
|
| 391 |
+
padding=False,
|
| 392 |
+
truncation=True,
|
| 393 |
+
)
|
| 394 |
+
uncond_inputs = {k: v.to(self.device) for k, v in uncond_inputs.items()}
|
| 395 |
+
|
| 396 |
+
# Use custom CFG generation with batch processing
|
| 397 |
+
# Combine conditional and unconditional inputs into a batch
|
| 398 |
+
# Format: [cond_input, uncond_input]
|
| 399 |
+
batch_input_ids = torch.cat([inputs['input_ids'], uncond_inputs['input_ids']], dim=0)
|
| 400 |
+
batch_attention_mask = None
|
| 401 |
+
if 'attention_mask' in inputs:
|
| 402 |
+
batch_attention_mask = torch.cat([inputs['attention_mask'], uncond_inputs.get('attention_mask', torch.ones_like(uncond_inputs['input_ids']))], dim=0)
|
| 403 |
+
|
| 404 |
+
# Custom CFG generation loop
|
| 405 |
+
outputs = self._generate_with_cfg(
|
| 406 |
+
batch_input_ids=batch_input_ids,
|
| 407 |
+
batch_attention_mask=batch_attention_mask,
|
| 408 |
max_new_tokens=max_new_tokens,
|
| 409 |
temperature=temperature,
|
| 410 |
+
cfg_scale=cfg_scale,
|
| 411 |
+
logits_processor=logits_processor,
|
| 412 |
pad_token_id=self.llm_tokenizer.pad_token_id or self.llm_tokenizer.eos_token_id,
|
| 413 |
streamer=streamer,
|
| 414 |
)
|
| 415 |
+
else:
|
| 416 |
+
# Generate without CFG
|
| 417 |
+
with torch.no_grad():
|
| 418 |
+
outputs = self.llm.generate(
|
| 419 |
+
**inputs,
|
| 420 |
+
max_new_tokens=max_new_tokens,
|
| 421 |
+
temperature=temperature if temperature > 0 else 1.0,
|
| 422 |
+
do_sample=True if temperature > 0 else False,
|
| 423 |
+
logits_processor=logits_processor if len(logits_processor) > 0 else None,
|
| 424 |
+
pad_token_id=self.llm_tokenizer.pad_token_id or self.llm_tokenizer.eos_token_id,
|
| 425 |
+
streamer=streamer,
|
| 426 |
+
)
|
| 427 |
|
| 428 |
# Decode the generated tokens
|
| 429 |
# Only decode the newly generated tokens (skip the input prompt)
|
|
|
|
| 438 |
error_msg = f"❌ Error generating with 5Hz LM: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
|
| 439 |
return {}, "", error_msg
|
| 440 |
|
| 441 |
+
def generate_with_5hz_lm(
|
| 442 |
+
self,
|
| 443 |
+
caption: str,
|
| 444 |
+
lyrics: str,
|
| 445 |
+
temperature: float = 0.6,
|
| 446 |
+
cfg_scale: float = 1.0,
|
| 447 |
+
negative_prompt: str = "NO USER INPUT",
|
| 448 |
+
top_k: Optional[int] = None,
|
| 449 |
+
top_p: Optional[float] = None,
|
| 450 |
+
repetition_penalty: float = 1.0,
|
| 451 |
+
) -> Tuple[Dict[str, Any], str, str]:
|
| 452 |
"""Generate metadata and audio codes using 5Hz LM"""
|
| 453 |
# Check if 5Hz LM is initialized
|
| 454 |
if not hasattr(self, 'llm_initialized') or not self.llm_initialized:
|
|
|
|
| 465 |
return {}, "", "❌ 5Hz LM backend not set. Please initialize it first."
|
| 466 |
|
| 467 |
if self.llm_backend == "vllm":
|
| 468 |
+
return self.generate_with_5hz_lm_vllm(
|
| 469 |
+
caption, lyrics, temperature, cfg_scale, negative_prompt,
|
| 470 |
+
top_k, top_p, repetition_penalty
|
| 471 |
+
)
|
| 472 |
else:
|
| 473 |
+
return self.generate_with_5hz_lm_pt(
|
| 474 |
+
caption, lyrics, temperature, cfg_scale, negative_prompt,
|
| 475 |
+
top_k, top_p, repetition_penalty
|
| 476 |
+
)
|
| 477 |
|
    def parse_lm_output(self, output_text: str) -> Tuple[Dict[str, Any], str]:
        """
...

        return metadata, audio_codes

+   def _generate_with_cfg(
+       self,
+       batch_input_ids: torch.Tensor,
+       batch_attention_mask: Optional[torch.Tensor],
+       max_new_tokens: int,
+       temperature: float,
+       cfg_scale: float,
+       logits_processor: Optional[LogitsProcessorList],
+       pad_token_id: int,
+       streamer: Optional[BaseStreamer],
+   ) -> torch.Tensor:
+       """
+       Custom generation loop with CFG support using batch processing.
+       Batch format: [conditional_input, unconditional_input]
+       This properly utilizes the KV cache by processing both sequences in parallel.
+       """
+       model = self.llm
+       device = self.device
+       batch_size = batch_input_ids.shape[0] // 2  # Half are conditional, half are unconditional
+       cond_start_idx = 0
+       uncond_start_idx = batch_size
+
+       # Initialize generated sequences
+       generated_ids = batch_input_ids.clone()
+       if batch_attention_mask is not None:
+           attention_mask = batch_attention_mask.clone()
+       else:
+           attention_mask = torch.ones_like(batch_input_ids)
+
+       # Prepare model inputs
+       model_kwargs = {}
+       if batch_attention_mask is not None:
+           model_kwargs['attention_mask'] = attention_mask
+
+       # Past key values for KV cache (if the model supports it)
+       past_key_values = None
+       use_cache = hasattr(model, 'generation_config') and getattr(model.generation_config, 'use_cache', True)
+
+       with torch.no_grad():
+           for step in range(max_new_tokens):
+               # Forward pass for the entire batch (conditional + unconditional)
+               if past_key_values is None:
+                   # First step: full forward pass
+                   outputs = model(
+                       input_ids=generated_ids,
+                       **model_kwargs,
+                       use_cache=use_cache,
+                   )
+               else:
+                   # Subsequent steps: only forward the last token (utilizing the KV cache)
+                   outputs = model(
+                       input_ids=generated_ids[:, -1:],
+                       past_key_values=past_key_values,
+                       **model_kwargs,
+                       use_cache=use_cache,
+                   )
+
+               # Get logits for the last position
+               next_token_logits = outputs.logits[:, -1, :]  # [batch_size*2, vocab_size]
+
+               # Split conditional and unconditional logits
+               cond_logits = next_token_logits[cond_start_idx:cond_start_idx + batch_size]
+               uncond_logits = next_token_logits[uncond_start_idx:uncond_start_idx + batch_size]
+
+               # Apply CFG formula: logits_cfg = logits_uncond + cfg_scale * (logits_cond - logits_uncond)
+               cfg_logits = uncond_logits + cfg_scale * (cond_logits - uncond_logits)
+
+               # Apply logits processors (top-k, top-p, repetition penalty)
+               if logits_processor is not None:
+                   # Get current input_ids for repetition penalty (only the conditional part)
+                   current_input_ids = generated_ids[cond_start_idx:cond_start_idx + batch_size]
+                   for processor in logits_processor:
+                       cfg_logits = processor(current_input_ids, cfg_logits)
+
+               # Apply temperature and sample
+               if temperature > 0:
+                   cfg_logits = cfg_logits / temperature
+                   probs = torch.softmax(cfg_logits, dim=-1)
+                   next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
+               else:
+                   next_tokens = torch.argmax(cfg_logits, dim=-1)
+
+               # Update generated sequences (append the same token to both conditional and unconditional halves)
+               next_tokens = next_tokens.unsqueeze(1)
+               generated_ids = torch.cat([generated_ids, next_tokens.repeat(2, 1)], dim=1)
+               attention_mask = torch.cat([attention_mask, torch.ones((batch_size * 2, 1), device=device, dtype=attention_mask.dtype)], dim=1)
+               model_kwargs['attention_mask'] = attention_mask
+
+               # Update past_key_values for the next iteration
+               if use_cache and hasattr(outputs, 'past_key_values'):
+                   past_key_values = outputs.past_key_values
+
+               # Update the streamer
+               if streamer is not None:
+                   streamer.put(next_tokens[0])  # Only stream conditional tokens
+
+               # Check for EOS (simplified - the model's eos_token_id could also be checked here)
+               if (next_tokens[0] == pad_token_id).all():
+                   break
+
+       if streamer is not None:
+           streamer.end()
+
+       # Return only the conditional output
+       return generated_ids[cond_start_idx:cond_start_idx + batch_size]
+
    @contextmanager
    def _load_model_context(self):
        """
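Before moving on to the nano-vllm changes, a small self-contained sanity check of the CFG combination used in _generate_with_cfg above. The tensor values are toy numbers chosen for illustration; the point is that the formula reduces to the conditional logits at cfg_scale = 1.0 and extrapolates past them for larger scales:

# Minimal sketch of the CFG combination; values are illustrative only.
import torch

cond = torch.tensor([[2.0, 0.0, -1.0]])    # conditional logits (toy values)
uncond = torch.tensor([[1.0, 0.5, -0.5]])  # unconditional logits (toy values)

def cfg_mix(cond, uncond, scale):
    # logits_cfg = logits_uncond + scale * (logits_cond - logits_uncond)
    return uncond + scale * (cond - uncond)

assert torch.allclose(cfg_mix(cond, uncond, 1.0), cond)  # scale 1.0 -> pure conditional
print(cfg_mix(cond, uncond, 2.0))                        # scale 2.0 -> pushes further from unconditional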
acestep/third_parts/nano-vllm/nanovllm/engine/model_runner.py
CHANGED
@@ -212,22 +212,37 @@ class ModelRunner:
        """Prepare sampling parameters. For CFG batch, only return parameters for conditional sequences."""
        if is_cfg_batch:
            # For CFG batch, seqs contains [cond_seq1, cond_seq2, ..., uncond_seq1, uncond_seq2, ...]
+           # We only need parameters for conditional sequences (first half)
            num_cond = len(seqs) // 2
            temperatures = []
            cfg_scales = []
+           top_ks = []
+           top_ps = []
+           repetition_penalties = []
            for seq in seqs[:num_cond]:
                temperatures.append(seq.temperature)
                cfg_scales.append(seq.cfg_scale)
+               top_ks.append(seq.top_k if seq.top_k is not None else 0)
+               top_ps.append(seq.top_p if seq.top_p is not None else 1.0)
+               repetition_penalties.append(seq.repetition_penalty)
        else:
            temperatures = []
            cfg_scales = []
+           top_ks = []
+           top_ps = []
+           repetition_penalties = []
            for seq in seqs:
                temperatures.append(seq.temperature)
                cfg_scales.append(seq.cfg_scale)
+               top_ks.append(seq.top_k if seq.top_k is not None else 0)
+               top_ps.append(seq.top_p if seq.top_p is not None else 1.0)
+               repetition_penalties.append(seq.repetition_penalty)
        temperatures = torch.tensor(temperatures, dtype=torch.float32, pin_memory=True).cuda(non_blocking=True)
        cfg_scales = torch.tensor(cfg_scales, dtype=torch.float32, pin_memory=True).cuda(non_blocking=True)
+       top_ks = torch.tensor(top_ks, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
+       top_ps = torch.tensor(top_ps, dtype=torch.float32, pin_memory=True).cuda(non_blocking=True)
+       repetition_penalties = torch.tensor(repetition_penalties, dtype=torch.float32, pin_memory=True).cuda(non_blocking=True)
+       return temperatures, cfg_scales, top_ks, top_ps, repetition_penalties

    @torch.inference_mode()
    def run_model(self, input_ids: torch.Tensor, positions: torch.Tensor, is_prefill: bool):

@@ -274,7 +289,11 @@ class ModelRunner:
        # Prepare inputs for both conditional and unconditional (they're already in the batch)
        input_ids, positions = (self.prepare_prefill(seqs) if is_prefill
                                else self.prepare_decode(seqs))
+       sample_params = self.prepare_sample(seqs, is_cfg_batch=True) if self.rank == 0 else None
+       if sample_params is not None:
+           temperatures, cfg_scales, top_ks, top_ps, repetition_penalties = sample_params
+       else:
+           temperatures = cfg_scales = top_ks = top_ps = repetition_penalties = None

        # Run model forward (processes entire batch: cond + uncond)
        logits_all = self.run_model(input_ids, positions, is_prefill)

@@ -285,12 +304,44 @@ class ModelRunner:
        logits_cond = logits_all[:num_cond]
        logits_uncond = logits_all[num_cond:]

+       # Apply repetition penalty to conditional logits (before CFG)
+       if repetition_penalties is not None:
+           for i, seq in enumerate(cond_seqs):
+               penalty = repetition_penalties[i].item()
+               if penalty != 1.0:
+                   # Only penalize completion tokens (not prompt tokens)
+                   completion_tokens = torch.tensor(seq.completion_token_ids, device=logits_cond.device)
+                   if len(completion_tokens) > 0:
+                       # Create token mask: mark tokens that appeared in completion
+                       token_mask = torch.zeros(logits_cond.shape[1], dtype=torch.bool, device=logits_cond.device)
+                       token_mask[completion_tokens] = True
+
+                       # Apply standard repetition penalty formula (matching transformers implementation):
+                       # For tokens in completion: if score < 0 then score * penalty, else score / penalty
+                       penalty_scores = torch.where(
+                           logits_cond[i] < 0,
+                           logits_cond[i] * penalty,
+                           logits_cond[i] / penalty
+                       )
+                       # Only apply penalty to tokens that appeared in completion
+                       logits_cond[i] = torch.where(token_mask, penalty_scores, logits_cond[i])
+
+       # Apply CFG formula: logits_cfg = logits_uncond + cfg_scale * (logits_cond - logits_uncond)
        cfg_scales_tensor = cfg_scales.unsqueeze(1)  # [num_cond, 1]
+       logits_cfg = logits_uncond + cfg_scales_tensor * (logits_cond - logits_uncond)
+
+       # Prepare input_ids for the sampler (for repetition penalty, though it was already applied above)
+       cond_input_ids = torch.tensor([seq.token_ids for seq in cond_seqs], device=logits_cfg.device)

        # Sample from CFG logits
+       token_ids_cfg = self.sampler(
+           logits_cfg,
+           temperatures,
+           top_ks=top_ks if top_ks is not None else None,
+           top_ps=top_ps if top_ps is not None else None,
+           repetition_penalties=None,  # Already applied above
+           input_ids=cond_input_ids,
+       ).tolist()

        # Return token_ids (will be applied to both conditional and unconditional sequences)
        return token_ids_cfg

@@ -300,11 +351,51 @@ class ModelRunner:
        # Normal batch (non-CFG)
        input_ids, positions = (self.prepare_prefill(seqs) if is_prefill
                                else self.prepare_decode(seqs))
+       sample_params = self.prepare_sample(seqs, is_cfg_batch=False) if self.rank == 0 else None
+       if sample_params is not None:
+           temperatures, cfg_scales, top_ks, top_ps, repetition_penalties = sample_params
+       else:
+           temperatures = cfg_scales = top_ks = top_ps = repetition_penalties = None
        logits = self.run_model(input_ids, positions, is_prefill)
        reset_context()
+
+       if self.rank == 0:
+           # Apply repetition penalty to logits
+           if repetition_penalties is not None:
+               for i, seq in enumerate(seqs):
+                   penalty = repetition_penalties[i].item()
+                   if penalty != 1.0:
+                       # Only penalize completion tokens (not prompt tokens)
+                       completion_tokens = torch.tensor(seq.completion_token_ids, device=logits.device)
+                       if len(completion_tokens) > 0:
+                           # Create token mask: mark tokens that appeared in completion
+                           token_mask = torch.zeros(logits.shape[1], dtype=torch.bool, device=logits.device)
+                           token_mask[completion_tokens] = True
+
+                           # Apply standard repetition penalty formula (matching transformers implementation):
+                           # For tokens in completion: if score < 0 then score * penalty, else score / penalty
+                           penalty_scores = torch.where(
+                               logits[i] < 0,
+                               logits[i] * penalty,
+                               logits[i] / penalty
+                           )
+                           # Only apply penalty to tokens that appeared in completion
+                           logits[i] = torch.where(token_mask, penalty_scores, logits[i])
+
+           # Prepare input_ids for the sampler
+           seq_input_ids = torch.tensor([seq.token_ids for seq in seqs], device=logits.device)
+
+           token_ids = self.sampler(
+               logits,
+               temperatures,
+               top_ks=top_ks if top_ks is not None else None,
+               top_ps=top_ps if top_ps is not None else None,
+               repetition_penalties=None,  # Already applied above
+               input_ids=seq_input_ids,
+           ).tolist()
+           return token_ids
+       else:
+           return None

    @torch.inference_mode()
    def capture_cudagraph(self):
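The repetition penalty applied in ModelRunner above follows the transformers convention described in its comments. A standalone toy sketch of that rule (illustrative logits and token ids, not tied to the ModelRunner code):

# Toy illustration of the transformers-style repetition penalty used above.
import torch

logits = torch.tensor([2.0, -1.0, 0.5, 3.0])  # scores for a 4-token vocabulary (toy values)
completion_token_ids = [1, 3]                 # tokens already generated in the completion
penalty = 1.2                                 # > 1.0 discourages repeating these tokens

mask = torch.zeros_like(logits, dtype=torch.bool)
mask[completion_token_ids] = True

# Negative scores are multiplied by the penalty, positive scores divided by it,
# so repeated tokens become less likely either way.
penalized = torch.where(logits < 0, logits * penalty, logits / penalty)
logits = torch.where(mask, penalized, logits)
print(logits)  # token 1: -1.0 -> -1.2, token 3: 3.0 -> 2.5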
acestep/third_parts/nano-vllm/nanovllm/engine/sequence.py
CHANGED
@@ -28,6 +28,9 @@ class Sequence:
        self.max_tokens = sampling_params.max_tokens
        self.ignore_eos = sampling_params.ignore_eos
        self.cfg_scale = sampling_params.cfg_scale
+       self.top_k = sampling_params.top_k
+       self.top_p = sampling_params.top_p
+       self.repetition_penalty = sampling_params.repetition_penalty
        # For CFG: mark if this is an unconditional sequence
        self.is_unconditional = is_unconditional
        # For CFG: reference to the corresponding conditional sequence (if this is unconditional)
CHANGED
|
@@ -1,5 +1,6 @@
|
|
| 1 |
import torch
|
| 2 |
from torch import nn
|
|
|
|
| 3 |
|
| 4 |
|
| 5 |
class Sampler(nn.Module):
|
|
@@ -8,8 +9,66 @@ class Sampler(nn.Module):
|
|
| 8 |
super().__init__()
|
| 9 |
|
| 10 |
@torch.compile
|
| 11 |
-
def forward(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
logits = logits.float().div_(temperatures.unsqueeze(dim=1))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
probs = torch.softmax(logits, dim=-1)
|
| 14 |
sample_tokens = probs.div_(torch.empty_like(probs).exponential_(1).clamp_min_(1e-10)).argmax(dim=-1)
|
| 15 |
return sample_tokens
|
|
|
|
| 1 |
import torch
|
| 2 |
from torch import nn
|
| 3 |
+
from typing import Optional
|
| 4 |
|
| 5 |
|
| 6 |
class Sampler(nn.Module):
|
|
|
|
| 9 |
super().__init__()
|
| 10 |
|
| 11 |
@torch.compile
|
| 12 |
+
def forward(
|
| 13 |
+
self,
|
| 14 |
+
logits: torch.Tensor,
|
| 15 |
+
temperatures: torch.Tensor,
|
| 16 |
+
top_ks: Optional[torch.Tensor] = None,
|
| 17 |
+
top_ps: Optional[torch.Tensor] = None,
|
| 18 |
+
repetition_penalties: Optional[torch.Tensor] = None,
|
| 19 |
+
input_ids: Optional[torch.Tensor] = None,
|
| 20 |
+
):
|
| 21 |
+
"""
|
| 22 |
+
Sample tokens from logits with optional top-k, top-p, and repetition penalty.
|
| 23 |
+
|
| 24 |
+
Args:
|
| 25 |
+
logits: [batch_size, vocab_size] logits tensor
|
| 26 |
+
temperatures: [batch_size] temperature values
|
| 27 |
+
top_ks: Optional [batch_size] top-k values (None or 0 means no top-k filtering)
|
| 28 |
+
top_ps: Optional [batch_size] top-p values (None or 1.0 means no top-p filtering)
|
| 29 |
+
repetition_penalties: Optional [batch_size] repetition penalty values (1.0 means no penalty)
|
| 30 |
+
input_ids: Optional [batch_size, seq_len] input token ids for repetition penalty
|
| 31 |
+
"""
|
| 32 |
+
batch_size, vocab_size = logits.shape
|
| 33 |
+
|
| 34 |
+
# Note: Repetition penalty is applied in ModelRunner before calling sampler
|
| 35 |
+
# This allows us to use the full sequence context
|
| 36 |
+
|
| 37 |
+
# Apply temperature
|
| 38 |
logits = logits.float().div_(temperatures.unsqueeze(dim=1))
|
| 39 |
+
|
| 40 |
+
# Apply top-k filtering if specified
|
| 41 |
+
if top_ks is not None:
|
| 42 |
+
for i in range(batch_size):
|
| 43 |
+
top_k = top_ks[i].item()
|
| 44 |
+
if top_k > 0 and top_k < vocab_size:
|
| 45 |
+
# Get top-k logits, set others to -inf
|
| 46 |
+
top_k_logits, top_k_indices = torch.topk(logits[i], int(top_k), dim=-1)
|
| 47 |
+
filtered_logits = torch.full_like(logits[i], float('-inf'))
|
| 48 |
+
filtered_logits[top_k_indices] = top_k_logits
|
| 49 |
+
logits[i] = filtered_logits
|
| 50 |
+
|
| 51 |
+
# Apply top-p (nucleus) filtering if specified
|
| 52 |
+
if top_ps is not None:
|
| 53 |
+
probs = torch.softmax(logits, dim=-1)
|
| 54 |
+
for i in range(batch_size):
|
| 55 |
+
top_p = top_ps[i].item()
|
| 56 |
+
if 0.0 < top_p < 1.0:
|
| 57 |
+
# Sort probabilities in descending order
|
| 58 |
+
sorted_probs, sorted_indices = torch.sort(probs[i], descending=True)
|
| 59 |
+
# Calculate cumulative probabilities
|
| 60 |
+
cumsum_probs = torch.cumsum(sorted_probs, dim=-1)
|
| 61 |
+
# Find the cutoff point
|
| 62 |
+
cutoff_idx = (cumsum_probs <= top_p).sum().item()
|
| 63 |
+
if cutoff_idx < len(sorted_indices):
|
| 64 |
+
cutoff_idx += 1 # Include one more token to ensure we have at least one
|
| 65 |
+
# Create mask for tokens to keep
|
| 66 |
+
mask = torch.zeros_like(probs[i])
|
| 67 |
+
mask[sorted_indices[:cutoff_idx]] = 1.0
|
| 68 |
+
# Apply mask: set filtered tokens to -inf
|
| 69 |
+
logits[i] = torch.where(mask > 0, logits[i], torch.tensor(float('-inf'), device=logits.device))
|
| 70 |
+
|
| 71 |
+
# Sample using Gumbel-max trick (equivalent to sampling from softmax)
|
| 72 |
probs = torch.softmax(logits, dim=-1)
|
| 73 |
sample_tokens = probs.div_(torch.empty_like(probs).exponential_(1).clamp_min_(1e-10)).argmax(dim=-1)
|
| 74 |
return sample_tokens
|
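To make the filtering above concrete, a small standalone sketch for a single row of toy logits, showing top-k masking, top-p masking, and sampling via division by Exponential(1) noise (the Gumbel-max trick). This mirrors the logic of the Sampler but is not the class itself; all values are illustrative:

# Standalone toy demo of top-k / top-p masking and Gumbel-max sampling.
import torch

logits = torch.tensor([2.0, 1.0, 0.5, -1.0, -3.0])  # one row of toy logits

# Top-k: keep the 3 largest logits, mask the rest to -inf.
top_k = 3
kth_value = torch.topk(logits, top_k).values.min()
logits = torch.where(logits >= kth_value, logits, torch.tensor(float('-inf')))

# Top-p: keep the smallest set of tokens whose cumulative probability covers 0.9 (plus one).
top_p = 0.9
probs = torch.softmax(logits, dim=-1)
sorted_probs, sorted_idx = torch.sort(probs, descending=True)
cutoff = int((torch.cumsum(sorted_probs, dim=-1) <= top_p).sum()) + 1
keep = torch.zeros_like(probs, dtype=torch.bool)
keep[sorted_idx[:cutoff]] = True
logits = torch.where(keep, logits, torch.tensor(float('-inf')))

# Gumbel-max trick: dividing the probabilities by Exponential(1) noise and taking the
# argmax draws a sample from the categorical distribution they define.
probs = torch.softmax(logits, dim=-1)
token = probs.div(torch.empty_like(probs).exponential_(1).clamp_min(1e-10)).argmax()
print(int(token))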
acestep/third_parts/nano-vllm/nanovllm/sampling_params.py
CHANGED
@@ -1,4 +1,5 @@
from dataclasses import dataclass
+from typing import Optional


@dataclass

@@ -7,7 +8,15 @@ class SamplingParams:
    max_tokens: int = 64
    ignore_eos: bool = False
    cfg_scale: float = 1.0  # CFG guidance scale. When > 1.0, applies classifier-free guidance
+   top_k: Optional[int] = None  # Top-k sampling: consider only the top k tokens
+   top_p: Optional[float] = None  # Top-p (nucleus) sampling: consider tokens with cumulative probability <= top_p
+   repetition_penalty: float = 1.0  # Repetition penalty: > 1.0 reduces repetition, < 1.0 increases it

    def __post_init__(self):
        assert self.temperature > 1e-10, "greedy sampling is not permitted"
        assert self.cfg_scale >= 1.0, "cfg_scale must be >= 1.0"
+       if self.top_k is not None:
+           assert self.top_k > 0, "top_k must be > 0"
+       if self.top_p is not None:
+           assert 0.0 < self.top_p <= 1.0, "top_p must be in (0.0, 1.0]"
+       assert self.repetition_penalty > 0.0, "repetition_penalty must be > 0.0"
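A short usage sketch of the extended dataclass; the parameter values are illustrative and the import path is assumed from the file layout of the bundled nano-vllm copy. The __post_init__ assertions reject out-of-range settings:

# Illustrative construction of the extended SamplingParams (import path assumed).
from nanovllm.sampling_params import SamplingParams

params = SamplingParams(
    temperature=0.6,
    max_tokens=512,
    cfg_scale=1.5,           # classifier-free guidance enabled (> 1.0)
    top_k=50,
    top_p=0.9,
    repetition_penalty=1.1,
)

try:
    SamplingParams(temperature=0.6, top_p=1.5)  # invalid: top_p must be in (0.0, 1.0]
except AssertionError as err:
    print(err)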