ChuxiJ committed
Commit
376c43e
·
1 Parent(s): 5ac3586

refactor handler
.gitignore CHANGED
@@ -213,3 +213,5 @@ tests/
 checkpoints/
 playground.ipynb
 .history/
+upload_checkpoints.sh
+checkpoints.7z
acestep/acestep_v15_pipeline.py CHANGED
@@ -15,20 +15,33 @@ from .dataset_handler import DatasetHandler
 from .gradio_ui import create_gradio_interface
 
 
-def create_demo():
+def create_demo(init_params=None):
     """
     Create Gradio demo interface
 
+    Args:
+        init_params: Dictionary containing initialization parameters and state.
+            If None, service will not be pre-initialized.
+            Keys: 'pre_initialized' (bool), 'checkpoint', 'config_path', 'device',
+            'init_llm', 'lm_model_path', 'backend', 'use_flash_attention',
+            'offload_to_cpu', 'offload_dit_to_cpu', 'init_status',
+            'dit_handler', 'llm_handler' (initialized handlers if pre-initialized)
+
     Returns:
         Gradio Blocks instance
     """
-    # Create independent handler instances
-    dit_handler = AceStepHandler()  # DiT handler
-    llm_handler = LLMHandler()  # LM handler
+    # Use pre-initialized handlers if available, otherwise create new ones
+    if init_params and init_params.get('pre_initialized') and 'dit_handler' in init_params:
+        dit_handler = init_params['dit_handler']
+        llm_handler = init_params['llm_handler']
+    else:
+        dit_handler = AceStepHandler()  # DiT handler
+        llm_handler = LLMHandler()  # LM handler
+
     dataset_handler = DatasetHandler()  # Dataset handler
 
-    # Create Gradio interface with all handlers
-    demo = create_gradio_interface(dit_handler, llm_handler, dataset_handler)
+    # Create Gradio interface with all handlers and initialization parameters
+    demo = create_gradio_interface(dit_handler, llm_handler, dataset_handler, init_params=init_params)
 
     return demo
 
@@ -42,12 +55,124 @@ def main():
     parser.add_argument("--share", action="store_true", help="Create a public link")
     parser.add_argument("--debug", action="store_true", help="Enable debug mode")
     parser.add_argument("--server-name", type=str, default="127.0.0.1", help="Server name (default: 127.0.0.1, use 0.0.0.0 for all interfaces)")
+
+    # Service initialization arguments
+    parser.add_argument("--init_service", type=lambda x: x.lower() in ['true', '1', 'yes'], default=False, help="Initialize service on startup (default: False)")
+    parser.add_argument("--checkpoint", type=str, default=None, help="Checkpoint file path (optional, for display purposes)")
+    parser.add_argument("--config_path", type=str, default=None, help="Main model path (e.g., 'acestep-v15-turbo')")
+    parser.add_argument("--device", type=str, default="auto", choices=["auto", "cuda", "cpu"], help="Processing device (default: auto)")
+    parser.add_argument("--init_llm", type=lambda x: x.lower() in ['true', '1', 'yes'], default=True, help="Initialize 5Hz LM (default: True)")
+    parser.add_argument("--lm_model_path", type=str, default=None, help="5Hz LM model path (e.g., 'acestep-5Hz-lm-0.6B')")
+    parser.add_argument("--backend", type=str, default="vllm", choices=["vllm", "pt"], help="5Hz LM backend (default: vllm)")
+    parser.add_argument("--use_flash_attention", type=lambda x: x.lower() in ['true', '1', 'yes'], default=None, help="Use flash attention (default: auto-detect)")
+    parser.add_argument("--offload_to_cpu", type=lambda x: x.lower() in ['true', '1', 'yes'], default=False, help="Offload models to CPU (default: False)")
+    parser.add_argument("--offload_dit_to_cpu", type=lambda x: x.lower() in ['true', '1', 'yes'], default=False, help="Offload DiT to CPU (default: False)")
+
     args = parser.parse_args()
 
     try:
+        init_params = None
+
+        # If init_service is True, perform initialization before creating UI
+        if args.init_service:
+            print("Initializing service from command line...")
+
+            # Create handler instances for initialization
+            dit_handler = AceStepHandler()
+            llm_handler = LLMHandler()
+
+            # Auto-select config_path if not provided
+            if args.config_path is None:
+                available_models = dit_handler.get_available_acestep_v15_models()
+                if available_models:
+                    args.config_path = "acestep-v15-turbo" if "acestep-v15-turbo" in available_models else available_models[0]
+                    print(f"Auto-selected config_path: {args.config_path}")
+                else:
+                    print("Error: No available models found. Please specify --config_path", file=sys.stderr)
+                    sys.exit(1)
+
+            # Get project root (same logic as in handler)
+            current_file = os.path.abspath(__file__)
+            project_root = os.path.dirname(os.path.dirname(current_file))
+
+            # Determine flash attention setting
+            use_flash_attention = args.use_flash_attention
+            if use_flash_attention is None:
+                use_flash_attention = dit_handler.is_flash_attention_available()
+
+            # Initialize DiT handler
+            print(f"Initializing DiT model: {args.config_path} on {args.device}...")
+            init_status, enable_generate = dit_handler.initialize_service(
+                project_root=project_root,
+                config_path=args.config_path,
+                device=args.device,
+                use_flash_attention=use_flash_attention,
+                compile_model=False,
+                offload_to_cpu=args.offload_to_cpu,
+                offload_dit_to_cpu=args.offload_dit_to_cpu
+            )
+
+            if not enable_generate:
+                print(f"Error initializing DiT model: {init_status}", file=sys.stderr)
+                sys.exit(1)
+
+            print(f"DiT model initialized successfully")
+
+            # Initialize LM handler if requested
+            lm_status = ""
+            if args.init_llm:
+                if args.lm_model_path is None:
+                    # Try to get default LM model
+                    available_lm_models = llm_handler.get_available_5hz_lm_models()
+                    if available_lm_models:
+                        args.lm_model_path = available_lm_models[0]
+                        print(f"Using default LM model: {args.lm_model_path}")
+                    else:
+                        print("Warning: No LM models available, skipping LM initialization", file=sys.stderr)
+                        args.init_llm = False
+
+                if args.init_llm and args.lm_model_path:
+                    checkpoint_dir = os.path.join(project_root, "checkpoints")
+                    print(f"Initializing 5Hz LM: {args.lm_model_path} on {args.device}...")
+                    lm_status, lm_success = llm_handler.initialize(
+                        checkpoint_dir=checkpoint_dir,
+                        lm_model_path=args.lm_model_path,
+                        backend=args.backend,
+                        device=args.device,
+                        offload_to_cpu=args.offload_to_cpu,
+                        dtype=dit_handler.dtype
+                    )
+
+                    if lm_success:
+                        print(f"5Hz LM initialized successfully")
+                        init_status += f"\n{lm_status}"
+                    else:
+                        print(f"Warning: 5Hz LM initialization failed: {lm_status}", file=sys.stderr)
+                        init_status += f"\n{lm_status}"
+
+            # Prepare initialization parameters for UI
+            init_params = {
+                'pre_initialized': True,
+                'checkpoint': args.checkpoint,
+                'config_path': args.config_path,
+                'device': args.device,
+                'init_llm': args.init_llm,
+                'lm_model_path': args.lm_model_path,
+                'backend': args.backend,
+                'use_flash_attention': use_flash_attention,
+                'offload_to_cpu': args.offload_to_cpu,
+                'offload_dit_to_cpu': args.offload_dit_to_cpu,
+                'init_status': init_status,
+                'enable_generate': enable_generate,
+                'dit_handler': dit_handler,
+                'llm_handler': llm_handler
+            }
+
+            print("Service initialization completed successfully!")
+
         # Create and launch demo
         print("Creating Gradio interface...")
-        demo = create_demo()
+        demo = create_demo(init_params=init_params)
         print(f"Launching server on {args.server_name}:{args.port}...")
         demo.launch(
             server_name=args.server_name,
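
The new service-initialization flags above parse booleans with a small lambda rather than argparse's store_true, so values such as "--init_service true" or "--init_service 0" work from the command line. A minimal standalone sketch of that convention follows; the flag names mirror the diff, while the parsed values are illustrative and not project defaults.

# Hedged sketch: how the lambda-based boolean flags above behave. Any of
# "true" / "1" / "yes" (case-insensitive) parses to True; everything else to False.
import argparse

str2bool = lambda x: x.lower() in ['true', '1', 'yes']

parser = argparse.ArgumentParser()
parser.add_argument("--init_service", type=str2bool, default=False)
parser.add_argument("--init_llm", type=str2bool, default=True)

args = parser.parse_args(["--init_service", "True", "--init_llm", "no"])
print(args.init_service)  # True  -> "true" is in the accepted set
print(args.init_llm)      # False -> "no" is not in the accepted set
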
acestep/gradio_ui.py CHANGED
@@ -7,7 +7,7 @@ import gradio as gr
7
  from typing import Callable, Optional
8
 
9
 
10
- def create_gradio_interface(dit_handler, llm_handler, dataset_handler) -> gr.Blocks:
11
  """
12
  Create Gradio interface
13
 
@@ -15,6 +15,8 @@ def create_gradio_interface(dit_handler, llm_handler, dataset_handler) -> gr.Blo
15
  dit_handler: DiT handler instance
16
  llm_handler: LM handler instance
17
  dataset_handler: Dataset handler instance
 
 
18
 
19
  Returns:
20
  Gradio Blocks instance
@@ -47,8 +49,8 @@ def create_gradio_interface(dit_handler, llm_handler, dataset_handler) -> gr.Blo
47
  # Dataset Explorer Section
48
  dataset_section = create_dataset_section(dataset_handler)
49
 
50
- # Generation Section
51
- generation_section = create_generation_section(dit_handler, llm_handler)
52
 
53
  # Results Section
54
  results_section = create_results_section(dit_handler)
@@ -156,20 +158,33 @@ def create_dataset_section(dataset_handler) -> dict:
156
  }
157
 
158
 
159
- def create_generation_section(dit_handler, llm_handler) -> dict:
160
- """Create generation section"""
161
  with gr.Group():
162
  gr.HTML('<div class="section-header"><h3>🎼 ACE-Step V1.5 Demo </h3></div>')
163
 
164
- # Service Configuration
165
- with gr.Accordion("🔧 Service Configuration", open=True) as service_config_accordion:
 
166
  # Dropdown options section - all dropdowns grouped together
167
  with gr.Row(equal_height=True):
168
  with gr.Column(scale=4):
 
 
169
  checkpoint_dropdown = gr.Dropdown(
170
  label="Checkpoint File",
171
  choices=dit_handler.get_available_checkpoints(),
172
- value=None,
173
  info="Select a trained model checkpoint file (full path or filename)"
174
  )
175
  with gr.Column(scale=1, min_width=90):
@@ -180,15 +195,19 @@ def create_generation_section(dit_handler, llm_handler) -> dict:
180
  available_models = dit_handler.get_available_acestep_v15_models()
181
  default_model = "acestep-v15-turbo" if "acestep-v15-turbo" in available_models else (available_models[0] if available_models else None)
182
 
 
 
183
  config_path = gr.Dropdown(
184
  label="Main Model Path",
185
  choices=available_models,
186
- value=default_model,
187
  info="Select the model configuration directory (auto-scanned from checkpoints)"
188
  )
 
 
189
  device = gr.Dropdown(
190
  choices=["auto", "cuda", "cpu"],
191
- value="auto",
192
  label="Device",
193
  info="Processing device (auto-detect recommended)"
194
  )
@@ -198,47 +217,61 @@ def create_generation_section(dit_handler, llm_handler) -> dict:
198
  available_lm_models = llm_handler.get_available_5hz_lm_models()
199
  default_lm_model = "acestep-5Hz-lm-0.6B" if "acestep-5Hz-lm-0.6B" in available_lm_models else (available_lm_models[0] if available_lm_models else None)
200
 
 
 
201
  lm_model_path = gr.Dropdown(
202
  label="5Hz LM Model Path",
203
  choices=available_lm_models,
204
- value=default_lm_model,
205
  info="Select the 5Hz LM model checkpoint (auto-scanned from checkpoints)"
206
  )
 
 
207
  backend_dropdown = gr.Dropdown(
208
  choices=["vllm", "pt"],
209
- value="vllm",
210
  label="5Hz LM Backend",
211
  info="Select backend for 5Hz LM: vllm (faster) or pt (PyTorch, more compatible)"
212
  )
213
 
214
  # Checkbox options section - all checkboxes grouped together
215
  with gr.Row():
 
 
216
  init_llm_checkbox = gr.Checkbox(
217
  label="Initialize 5Hz LM",
218
- value=False,
219
  info="Check to initialize 5Hz LM during service initialization",
220
  )
221
  # Auto-detect flash attention availability
222
  flash_attn_available = dit_handler.is_flash_attention_available()
 
 
223
  use_flash_attention_checkbox = gr.Checkbox(
224
  label="Use Flash Attention",
225
- value=flash_attn_available,
226
  interactive=flash_attn_available,
227
  info="Enable flash attention for faster inference (requires flash_attn package)" if flash_attn_available else "Flash attention not available (flash_attn package not installed)"
228
  )
 
 
229
  offload_to_cpu_checkbox = gr.Checkbox(
230
  label="Offload to CPU",
231
- value=False,
232
  info="Offload models to CPU when not in use to save GPU memory"
233
  )
 
 
234
  offload_dit_to_cpu_checkbox = gr.Checkbox(
235
  label="Offload DiT to CPU",
236
- value=False,
237
  info="Offload DiT to CPU (needs Offload to CPU)"
238
  )
239
 
240
  init_btn = gr.Button("Initialize Service", variant="primary", size="lg")
241
- init_status = gr.Textbox(label="Status", interactive=False, lines=3)
 
 
242
 
243
  # Inputs
244
  with gr.Row():
@@ -328,7 +361,7 @@ def create_generation_section(dit_handler, llm_handler) -> dict:
328
  label="Temperature",
329
  minimum=0.0,
330
  maximum=2.0,
331
- value=0.7,
332
  step=0.1,
333
  scale=1,
334
  info="Temperature for 5Hz LM sampling (higher = more random, lower = more deterministic)"
@@ -337,18 +370,48 @@ def create_generation_section(dit_handler, llm_handler) -> dict:
337
  label="CFG Scale",
338
  minimum=1.0,
339
  maximum=3.0,
340
- value=1.0,
341
  step=0.1,
342
  scale=1,
343
  info="Classifier-Free Guidance scale for 5Hz LM (1.0 = no CFG, higher = stronger guidance)"
344
  )
345
 
346
  # Negative prompt for CFG (only visible when LM initialized and cfg_scale > 1)
347
  lm_negative_prompt = gr.Textbox(
348
  label="Negative Prompt",
349
  value="NO USER INPUT",
350
  placeholder="Enter negative prompt for CFG (default: NO USER INPUT)",
351
- visible=False,
352
  info="Negative prompt used for Classifier-Free Guidance when CFG Scale > 1.0",
353
  lines=2
354
  )
@@ -377,7 +440,7 @@ def create_generation_section(dit_handler, llm_handler) -> dict:
377
  step=0.01,
378
  label="Audio Cover Strength",
379
  info="Control how many denoising steps use cover mode",
380
- visible=False
381
  )
382
 
383
  # Music Caption
@@ -514,7 +577,9 @@ def create_generation_section(dit_handler, llm_handler) -> dict:
514
  interactive=False
515
  )
516
 
517
- generate_btn = gr.Button("🎵 Generate Music", variant="primary", size="lg", interactive=False)
 
 
518
 
519
  return {
520
  "checkpoint_dropdown": checkpoint_dropdown,
@@ -542,6 +607,9 @@ def create_generation_section(dit_handler, llm_handler) -> dict:
542
  "use_5hz_lm_btn": use_5hz_lm_btn,
543
  "lm_temperature": lm_temperature,
544
  "lm_cfg_scale": lm_cfg_scale,
 
 
 
545
  "lm_negative_prompt": lm_negative_prompt,
546
  "repainting_group": repainting_group,
547
  "repainting_start": repainting_start,
@@ -733,6 +801,47 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
733
 
734
  return status, gr.update(interactive=enable)
735
 
736
  generation_section["init_btn"].click(
737
  fn=init_service_wrapper,
738
  inputs=[
@@ -749,30 +858,6 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
749
  outputs=[generation_section["init_status"], generation_section["generate_btn"]]
750
  )
751
 
752
- # Update negative prompt visibility based on LM initialization and CFG scale
753
- def update_negative_prompt_visibility(init_status, cfg_scale):
754
- """Update negative prompt visibility: show only if LM initialized and cfg_scale > 1"""
755
- # Check if LM is initialized by looking for "5Hz LM backend:" in status
756
- lm_initialized = init_status is not None and "5Hz LM backend:" in str(init_status)
757
- # Check if cfg_scale > 1
758
- cfg_enabled = cfg_scale is not None and float(cfg_scale) > 1.0
759
- # Show only if both conditions are met
760
- return gr.update(visible=lm_initialized and cfg_enabled)
761
-
762
- # Update visibility when init_status changes
763
- generation_section["init_status"].change(
764
- fn=update_negative_prompt_visibility,
765
- inputs=[generation_section["init_status"], generation_section["lm_cfg_scale"]],
766
- outputs=[generation_section["lm_negative_prompt"]]
767
- )
768
-
769
- # Update visibility when cfg_scale changes
770
- generation_section["lm_cfg_scale"].change(
771
- fn=update_negative_prompt_visibility,
772
- inputs=[generation_section["init_status"], generation_section["lm_cfg_scale"]],
773
- outputs=[generation_section["lm_negative_prompt"]]
774
- )
775
-
776
  # Generation with progress bar
777
  def generate_with_progress(
778
  captions, lyrics, bpm, key_scale, time_signature, vocal_language,
@@ -845,9 +930,16 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
845
  )
846
 
847
  # 5Hz LM generation (simplified version, can be extended as needed)
848
- def generate_lm_hints_wrapper(caption, lyrics, temperature, cfg_scale, negative_prompt):
849
  """Wrapper for 5Hz LM generation"""
850
- metadata, audio_codes, status = llm_handler.generate_with_5hz_lm(caption, lyrics, temperature, cfg_scale, negative_prompt)
851
 
852
  # Extract metadata values and map to UI fields
853
  # Handle bpm
@@ -886,6 +978,9 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
886
  generation_section["lyrics"],
887
  generation_section["lm_temperature"],
888
  generation_section["lm_cfg_scale"],
 
 
 
889
  generation_section["lm_negative_prompt"]
890
  ],
891
  outputs=[
@@ -902,7 +997,8 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
902
  task_type_value: str,
903
  track_name_value: Optional[str],
904
  complete_track_classes_value: list,
905
- audio_codes_content: str = ""
 
906
  ) -> tuple:
907
  """Update instruction and UI visibility based on task type."""
908
  instruction = dit_handler.generate_instruction(
@@ -915,8 +1011,15 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
915
  track_name_visible = task_type_value in ["lego", "extract"]
916
  # Show complete_track_classes for complete
917
  complete_visible = task_type_value == "complete"
918
- # Show audio_cover_strength for cover
919
- audio_cover_strength_visible = task_type_value == "cover"
  # Show audio_code_string for cover
921
  audio_code_visible = task_type_value == "cover"
922
  # Show repainting controls for repaint and lego
@@ -932,7 +1035,7 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
932
  instruction, # instruction_display_gen
933
  gr.update(visible=track_name_visible), # track_name
934
  gr.update(visible=complete_visible), # complete_track_classes
935
- gr.update(visible=audio_cover_strength_visible), # audio_cover_strength
936
  gr.update(visible=repainting_visible), # repainting_group
937
  gr.update(visible=audio_code_visible), # audio_code_string
938
  gr.update(visible=use_5hz_lm_visible), # use_5hz_lm_row
@@ -946,7 +1049,8 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
946
  generation_section["task_type"],
947
  generation_section["track_name"],
948
  generation_section["complete_track_classes"],
949
- generation_section["text2music_audio_code_string"]
 
950
  ],
951
  outputs=[
952
  generation_section["instruction_display_gen"],
@@ -967,7 +1071,8 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
967
  generation_section["task_type"],
968
  generation_section["track_name"],
969
  generation_section["complete_track_classes"],
970
- generation_section["text2music_audio_code_string"]
 
971
  ],
972
  outputs=[
973
  generation_section["instruction_display_gen"],
@@ -988,7 +1093,8 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
988
  generation_section["task_type"],
989
  generation_section["track_name"],
990
  generation_section["complete_track_classes"],
991
- generation_section["text2music_audio_code_string"]
 
992
  ],
993
  outputs=[
994
  generation_section["instruction_display_gen"],
 
7
  from typing import Callable, Optional
8
 
9
 
10
+ def create_gradio_interface(dit_handler, llm_handler, dataset_handler, init_params=None) -> gr.Blocks:
11
  """
12
  Create Gradio interface
13
 
 
15
  dit_handler: DiT handler instance
16
  llm_handler: LM handler instance
17
  dataset_handler: Dataset handler instance
18
+ init_params: Dictionary containing initialization parameters and state.
19
+ If None, service will not be pre-initialized.
20
 
21
  Returns:
22
  Gradio Blocks instance
 
49
  # Dataset Explorer Section
50
  dataset_section = create_dataset_section(dataset_handler)
51
 
52
+ # Generation Section (pass init_params to support pre-initialization)
53
+ generation_section = create_generation_section(dit_handler, llm_handler, init_params=init_params)
54
 
55
  # Results Section
56
  results_section = create_results_section(dit_handler)
 
158
  }
159
 
160
 
161
+ def create_generation_section(dit_handler, llm_handler, init_params=None) -> dict:
162
+ """Create generation section
163
+
164
+ Args:
165
+ dit_handler: DiT handler instance
166
+ llm_handler: LM handler instance
167
+ init_params: Dictionary containing initialization parameters and state.
168
+ If None, service will not be pre-initialized.
169
+ """
170
+ # Check if service is pre-initialized
171
+ service_pre_initialized = init_params is not None and init_params.get('pre_initialized', False)
172
+
173
  with gr.Group():
174
  gr.HTML('<div class="section-header"><h3>🎼 ACE-Step V1.5 Demo </h3></div>')
175
 
176
+ # Service Configuration - collapse if pre-initialized
177
+ accordion_open = not service_pre_initialized
178
+ with gr.Accordion("🔧 Service Configuration", open=accordion_open) as service_config_accordion:
179
  # Dropdown options section - all dropdowns grouped together
180
  with gr.Row(equal_height=True):
181
  with gr.Column(scale=4):
182
+ # Set checkpoint value from init_params if pre-initialized
183
+ checkpoint_value = init_params.get('checkpoint') if service_pre_initialized else None
184
  checkpoint_dropdown = gr.Dropdown(
185
  label="Checkpoint File",
186
  choices=dit_handler.get_available_checkpoints(),
187
+ value=checkpoint_value,
188
  info="Select a trained model checkpoint file (full path or filename)"
189
  )
190
  with gr.Column(scale=1, min_width=90):
 
195
  available_models = dit_handler.get_available_acestep_v15_models()
196
  default_model = "acestep-v15-turbo" if "acestep-v15-turbo" in available_models else (available_models[0] if available_models else None)
197
 
198
+ # Set config_path value from init_params if pre-initialized
199
+ config_path_value = init_params.get('config_path', default_model) if service_pre_initialized else default_model
200
  config_path = gr.Dropdown(
201
  label="Main Model Path",
202
  choices=available_models,
203
+ value=config_path_value,
204
  info="Select the model configuration directory (auto-scanned from checkpoints)"
205
  )
206
+ # Set device value from init_params if pre-initialized
207
+ device_value = init_params.get('device', 'auto') if service_pre_initialized else 'auto'
208
  device = gr.Dropdown(
209
  choices=["auto", "cuda", "cpu"],
210
+ value=device_value,
211
  label="Device",
212
  info="Processing device (auto-detect recommended)"
213
  )
 
217
  available_lm_models = llm_handler.get_available_5hz_lm_models()
218
  default_lm_model = "acestep-5Hz-lm-0.6B" if "acestep-5Hz-lm-0.6B" in available_lm_models else (available_lm_models[0] if available_lm_models else None)
219
 
220
+ # Set lm_model_path value from init_params if pre-initialized
221
+ lm_model_path_value = init_params.get('lm_model_path', default_lm_model) if service_pre_initialized else default_lm_model
222
  lm_model_path = gr.Dropdown(
223
  label="5Hz LM Model Path",
224
  choices=available_lm_models,
225
+ value=lm_model_path_value,
226
  info="Select the 5Hz LM model checkpoint (auto-scanned from checkpoints)"
227
  )
228
+ # Set backend value from init_params if pre-initialized
229
+ backend_value = init_params.get('backend', 'vllm') if service_pre_initialized else 'vllm'
230
  backend_dropdown = gr.Dropdown(
231
  choices=["vllm", "pt"],
232
+ value=backend_value,
233
  label="5Hz LM Backend",
234
  info="Select backend for 5Hz LM: vllm (faster) or pt (PyTorch, more compatible)"
235
  )
236
 
237
  # Checkbox options section - all checkboxes grouped together
238
  with gr.Row():
239
+ # Set init_llm value from init_params if pre-initialized
240
+ init_llm_value = init_params.get('init_llm', True) if service_pre_initialized else True
241
  init_llm_checkbox = gr.Checkbox(
242
  label="Initialize 5Hz LM",
243
+ value=init_llm_value,
244
  info="Check to initialize 5Hz LM during service initialization",
245
  )
246
  # Auto-detect flash attention availability
247
  flash_attn_available = dit_handler.is_flash_attention_available()
248
+ # Set use_flash_attention value from init_params if pre-initialized
249
+ use_flash_attention_value = init_params.get('use_flash_attention', flash_attn_available) if service_pre_initialized else flash_attn_available
250
  use_flash_attention_checkbox = gr.Checkbox(
251
  label="Use Flash Attention",
252
+ value=use_flash_attention_value,
253
  interactive=flash_attn_available,
254
  info="Enable flash attention for faster inference (requires flash_attn package)" if flash_attn_available else "Flash attention not available (flash_attn package not installed)"
255
  )
256
+ # Set offload_to_cpu value from init_params if pre-initialized
257
+ offload_to_cpu_value = init_params.get('offload_to_cpu', False) if service_pre_initialized else False
258
  offload_to_cpu_checkbox = gr.Checkbox(
259
  label="Offload to CPU",
260
+ value=offload_to_cpu_value,
261
  info="Offload models to CPU when not in use to save GPU memory"
262
  )
263
+ # Set offload_dit_to_cpu value from init_params if pre-initialized
264
+ offload_dit_to_cpu_value = init_params.get('offload_dit_to_cpu', False) if service_pre_initialized else False
265
  offload_dit_to_cpu_checkbox = gr.Checkbox(
266
  label="Offload DiT to CPU",
267
+ value=offload_dit_to_cpu_value,
268
  info="Offload DiT to CPU (needs Offload to CPU)"
269
  )
270
 
271
  init_btn = gr.Button("Initialize Service", variant="primary", size="lg")
272
+ # Set init_status value from init_params if pre-initialized
273
+ init_status_value = init_params.get('init_status', '') if service_pre_initialized else ''
274
+ init_status = gr.Textbox(label="Status", interactive=False, lines=3, value=init_status_value)
275
 
276
  # Inputs
277
  with gr.Row():
 
361
  label="Temperature",
362
  minimum=0.0,
363
  maximum=2.0,
364
+ value=0.85,
365
  step=0.1,
366
  scale=1,
367
  info="Temperature for 5Hz LM sampling (higher = more random, lower = more deterministic)"
 
370
  label="CFG Scale",
371
  minimum=1.0,
372
  maximum=3.0,
373
+ value=2.0,
374
  step=0.1,
375
  scale=1,
376
  info="Classifier-Free Guidance scale for 5Hz LM (1.0 = no CFG, higher = stronger guidance)"
377
  )
378
 
379
+ with gr.Row():
380
+ lm_top_k = gr.Slider(
381
+ label="Top-K",
382
+ minimum=0,
383
+ maximum=100,
384
+ value=0,
385
+ step=1,
386
+ scale=1,
387
+ info="Top-K sampling: consider only top K tokens (0 = disabled)"
388
+ )
389
+ lm_top_p = gr.Slider(
390
+ label="Top-P",
391
+ minimum=0.0,
392
+ maximum=1.0,
393
+ value=0.9,
394
+ step=0.01,
395
+ scale=1,
396
+ info="Top-P (nucleus) sampling: cumulative probability threshold (1.0 = disabled)"
397
+ )
398
+ lm_repetition_penalty = gr.Slider(
399
+ label="Repetition Penalty",
400
+ minimum=0.8,
401
+ maximum=1.2,
402
+ value=1.0,
403
+ step=0.01,
404
+ scale=1,
405
+ info="Repetition penalty: >1.0 reduces repetition, <1.0 increases it (1.0 = no penalty). For audio generation, use 1.0 or very small values (1.01-1.05) as audio tokens naturally repeat.",
406
+ visible=False,
407
+ )
408
+
409
  # Negative prompt for CFG (only visible when LM initialized and cfg_scale > 1)
410
  lm_negative_prompt = gr.Textbox(
411
  label="Negative Prompt",
412
  value="NO USER INPUT",
413
  placeholder="Enter negative prompt for CFG (default: NO USER INPUT)",
414
+ visible=True,
415
  info="Negative prompt used for Classifier-Free Guidance when CFG Scale > 1.0",
416
  lines=2
417
  )
 
440
  step=0.01,
441
  label="Audio Cover Strength",
442
  info="Control how many denoising steps use cover mode",
443
+ visible=True
444
  )
445
 
446
  # Music Caption
 
577
  interactive=False
578
  )
579
 
580
+ # Set generate_btn to interactive if service is pre-initialized
581
+ generate_btn_interactive = init_params.get('enable_generate', False) if service_pre_initialized else False
582
+ generate_btn = gr.Button("🎵 Generate Music", variant="primary", size="lg", interactive=generate_btn_interactive)
583
 
584
  return {
585
  "checkpoint_dropdown": checkpoint_dropdown,
 
607
  "use_5hz_lm_btn": use_5hz_lm_btn,
608
  "lm_temperature": lm_temperature,
609
  "lm_cfg_scale": lm_cfg_scale,
610
+ "lm_top_k": lm_top_k,
611
+ "lm_top_p": lm_top_p,
612
+ "lm_repetition_penalty": lm_repetition_penalty,
613
  "lm_negative_prompt": lm_negative_prompt,
614
  "repainting_group": repainting_group,
615
  "repainting_start": repainting_start,
 
801
 
802
  return status, gr.update(interactive=enable)
803
 
804
+ # Update negative prompt visibility based on "Initialize 5Hz LM" checkbox
805
+ def update_negative_prompt_visibility(init_llm_checked):
806
+ """Update negative prompt visibility: show if Initialize 5Hz LM checkbox is checked"""
807
+ return gr.update(visible=init_llm_checked)
808
+
809
+ # Update audio_cover_strength visibility and label based on task type and LM initialization
810
+ def update_audio_cover_strength_visibility(task_type_value, init_llm_checked):
811
+ """Update audio_cover_strength visibility and label"""
812
+ # Show if task is cover OR if LM is initialized
813
+ is_visible = (task_type_value == "cover") or init_llm_checked
814
+ # Change label based on context
815
+ if init_llm_checked and task_type_value != "cover":
816
+ label = "LM codes strength"
817
+ info = "Control how many denoising steps use LM-generated codes"
818
+ else:
819
+ label = "Audio Cover Strength"
820
+ info = "Control how many denoising steps use cover mode"
821
+
822
+ return gr.update(visible=is_visible, label=label, info=info)
823
+
824
+ # Update visibility when init_llm_checkbox changes
825
+ generation_section["init_llm_checkbox"].change(
826
+ fn=update_negative_prompt_visibility,
827
+ inputs=[generation_section["init_llm_checkbox"]],
828
+ outputs=[generation_section["lm_negative_prompt"]]
829
+ )
830
+
831
+ # Update audio_cover_strength visibility and label when init_llm_checkbox changes
832
+ generation_section["init_llm_checkbox"].change(
833
+ fn=update_audio_cover_strength_visibility,
834
+ inputs=[generation_section["task_type"], generation_section["init_llm_checkbox"]],
835
+ outputs=[generation_section["audio_cover_strength"]]
836
+ )
837
+
838
+ # Also update audio_cover_strength when task_type changes (to handle label changes)
839
+ generation_section["task_type"].change(
840
+ fn=update_audio_cover_strength_visibility,
841
+ inputs=[generation_section["task_type"], generation_section["init_llm_checkbox"]],
842
+ outputs=[generation_section["audio_cover_strength"]]
843
+ )
844
+
845
  generation_section["init_btn"].click(
846
  fn=init_service_wrapper,
847
  inputs=[
 
858
  outputs=[generation_section["init_status"], generation_section["generate_btn"]]
859
  )
860
 
861
  # Generation with progress bar
862
  def generate_with_progress(
863
  captions, lyrics, bpm, key_scale, time_signature, vocal_language,
 
930
  )
931
 
932
  # 5Hz LM generation (simplified version, can be extended as needed)
933
+ def generate_lm_hints_wrapper(caption, lyrics, temperature, cfg_scale, top_k, top_p, repetition_penalty, negative_prompt):
934
  """Wrapper for 5Hz LM generation"""
935
+ # Convert top_k: 0 means None (disabled)
936
+ top_k_value = None if top_k == 0 else int(top_k)
937
+ # Convert top_p: 1.0 means None (disabled)
938
+ top_p_value = None if top_p >= 1.0 else top_p
939
+ metadata, audio_codes, status = llm_handler.generate_with_5hz_lm(
940
+ caption, lyrics, temperature, cfg_scale, negative_prompt,
941
+ top_k_value, top_p_value, repetition_penalty
942
+ )
943
 
944
  # Extract metadata values and map to UI fields
945
  # Handle bpm
 
978
  generation_section["lyrics"],
979
  generation_section["lm_temperature"],
980
  generation_section["lm_cfg_scale"],
981
+ generation_section["lm_top_k"],
982
+ generation_section["lm_top_p"],
983
+ generation_section["lm_repetition_penalty"],
984
  generation_section["lm_negative_prompt"]
985
  ],
986
  outputs=[
 
997
  task_type_value: str,
998
  track_name_value: Optional[str],
999
  complete_track_classes_value: list,
1000
+ audio_codes_content: str = "",
1001
+ init_llm_checked: bool = False
1002
  ) -> tuple:
1003
  """Update instruction and UI visibility based on task type."""
1004
  instruction = dit_handler.generate_instruction(
 
1011
  track_name_visible = task_type_value in ["lego", "extract"]
1012
  # Show complete_track_classes for complete
1013
  complete_visible = task_type_value == "complete"
1014
+ # Show audio_cover_strength for cover OR when LM is initialized
1015
+ audio_cover_strength_visible = (task_type_value == "cover") or init_llm_checked
1016
+ # Determine label and info based on context
1017
+ if init_llm_checked and task_type_value != "cover":
1018
+ audio_cover_strength_label = "LM codes strength"
1019
+ audio_cover_strength_info = "Control how many denoising steps use LM-generated codes"
1020
+ else:
1021
+ audio_cover_strength_label = "Audio Cover Strength"
1022
+ audio_cover_strength_info = "Control how many denoising steps use cover mode"
1023
  # Show audio_code_string for cover
1024
  audio_code_visible = task_type_value == "cover"
1025
  # Show repainting controls for repaint and lego
 
1035
  instruction, # instruction_display_gen
1036
  gr.update(visible=track_name_visible), # track_name
1037
  gr.update(visible=complete_visible), # complete_track_classes
1038
+ gr.update(visible=audio_cover_strength_visible, label=audio_cover_strength_label, info=audio_cover_strength_info), # audio_cover_strength
1039
  gr.update(visible=repainting_visible), # repainting_group
1040
  gr.update(visible=audio_code_visible), # audio_code_string
1041
  gr.update(visible=use_5hz_lm_visible), # use_5hz_lm_row
 
1049
  generation_section["task_type"],
1050
  generation_section["track_name"],
1051
  generation_section["complete_track_classes"],
1052
+ generation_section["text2music_audio_code_string"],
1053
+ generation_section["init_llm_checkbox"]
1054
  ],
1055
  outputs=[
1056
  generation_section["instruction_display_gen"],
 
1071
  generation_section["task_type"],
1072
  generation_section["track_name"],
1073
  generation_section["complete_track_classes"],
1074
+ generation_section["text2music_audio_code_string"],
1075
+ generation_section["init_llm_checkbox"]
1076
  ],
1077
  outputs=[
1078
  generation_section["instruction_display_gen"],
 
1093
  generation_section["task_type"],
1094
  generation_section["track_name"],
1095
  generation_section["complete_track_classes"],
1096
+ generation_section["text2music_audio_code_string"],
1097
+ generation_section["init_llm_checkbox"]
1098
  ],
1099
  outputs=[
1100
  generation_section["instruction_display_gen"],
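
The event wiring above leans on a single Gradio pattern: a change event returns gr.update(...) objects that toggle the visibility, label, and info text of other components. Below is a minimal standalone sketch of that pattern, assuming a recent Gradio version; the component names are illustrative and not the actual ACE-Step UI components.

# Hedged sketch of the gr.update() visibility pattern used above.
import gradio as gr

def toggle_negative_prompt(checked):
    # Show the textbox only when the checkbox is ticked
    return gr.update(visible=checked)

with gr.Blocks() as demo:
    init_llm = gr.Checkbox(label="Initialize 5Hz LM", value=False)
    negative_prompt = gr.Textbox(label="Negative Prompt", visible=False)
    init_llm.change(fn=toggle_negative_prompt, inputs=[init_llm], outputs=[negative_prompt])

if __name__ == "__main__":
    demo.launch()
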
acestep/handler.py CHANGED
@@ -1362,7 +1362,7 @@ class AceStepHandler:
1362
 
1363
  padded_non_cover_text_input_ids = None
1364
  padded_non_cover_text_attention_masks = None
1365
- if audio_cover_strength < 1.0 and is_covers is not None and is_covers.any():
1366
  non_cover_text_input_ids = []
1367
  non_cover_text_attention_masks = []
1368
  for i in range(batch_size):
@@ -1381,8 +1381,9 @@ class AceStepHandler:
1381
  return_tensors="pt",
1382
  )
1383
  text_token_ids = text_inputs_dict.input_ids[0]
 
1384
  non_cover_text_input_ids.append(text_token_ids)
1385
- non_cover_text_attention_masks.append(text_attention_mask)
1386
 
1387
  padded_non_cover_text_input_ids = torch.stack([
1388
  torch.nn.functional.pad(
 
1362
 
1363
  padded_non_cover_text_input_ids = None
1364
  padded_non_cover_text_attention_masks = None
1365
+ if audio_cover_strength < 1.0:
1366
  non_cover_text_input_ids = []
1367
  non_cover_text_attention_masks = []
1368
  for i in range(batch_size):
 
1381
  return_tensors="pt",
1382
  )
1383
  text_token_ids = text_inputs_dict.input_ids[0]
1384
+ non_cover_text_attention_mask = text_inputs_dict.attention_mask[0].bool()
1385
  non_cover_text_input_ids.append(text_token_ids)
1386
+ non_cover_text_attention_masks.append(non_cover_text_attention_mask)
1387
 
1388
  padded_non_cover_text_input_ids = torch.stack([
1389
  torch.nn.functional.pad(
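
The handler fix above keeps the per-sample attention mask returned by the tokenizer and pads it alongside the input ids before stacking into a batch. A toy sketch of that pad-and-stack pattern follows; the tensors and lengths are made up, not the handler's actual data.

# Hedged toy example of right-padding variable-length sequences and masks
# to a common length before torch.stack.
import torch
import torch.nn.functional as F

token_id_seqs = [torch.tensor([5, 6, 7]), torch.tensor([8, 9])]
attention_masks = [torch.ones(len(s), dtype=torch.long) for s in token_id_seqs]
max_len = max(len(s) for s in token_id_seqs)

padded_ids = torch.stack([
    F.pad(s, (0, max_len - len(s)), value=0) for s in token_id_seqs
])
padded_masks = torch.stack([
    F.pad(m, (0, max_len - len(m)), value=0) for m in attention_masks
]).bool()

print(padded_ids)    # tensor([[5, 6, 7], [8, 9, 0]])
print(padded_masks)  # tensor([[ True,  True,  True], [ True,  True, False]])
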
acestep/llm_inference.py CHANGED
@@ -11,8 +11,18 @@ from contextlib import contextmanager
11
  import torch
12
  from tqdm import tqdm
13
  from loguru import logger
14
- from transformers import AutoTokenizer, AutoModelForCausalLM
15
  from transformers.generation.streamers import BaseStreamer
16
 
17
 
18
  class LLMHandler:
@@ -209,7 +219,17 @@ class LLMHandler:
209
  error_msg = f"❌ Error initializing 5Hz LM: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
210
  return error_msg
211
 
212
- def generate_with_5hz_lm_vllm(self, caption: str, lyrics: str, temperature: float = 0.6, cfg_scale: float = 1.0, negative_prompt: str = "NO USER INPUT") -> Tuple[Dict[str, Any], str, str]:
213
  """Generate metadata and audio codes using 5Hz LM with vllm backend"""
214
  try:
215
  from nanovllm import SamplingParams
@@ -226,7 +246,14 @@ class LLMHandler:
226
  )
227
  logger.debug(f"[debug] formatted_prompt: {formatted_prompt}")
228
 
229
- sampling_params = SamplingParams(max_tokens=self.max_model_len-64, temperature=temperature, cfg_scale=cfg_scale)
230
  # Use CFG if cfg_scale > 1.0
231
  if cfg_scale > 1.0:
232
  # Build unconditional prompt (user input replaced with "NO USER INPUT")
@@ -266,7 +293,17 @@ class LLMHandler:
266
  error_msg = f"❌ Error generating with 5Hz LM: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
267
  return {}, "", error_msg
268
 
269
- def generate_with_5hz_lm_pt(self, caption: str, lyrics: str, temperature: float = 0.6) -> Tuple[Dict[str, Any], str, str]:
270
  """Generate metadata and audio codes using 5Hz LM with PyTorch backend"""
271
  try:
272
  prompt = f"# Caption\n{caption}\n\n# Lyric\n{lyrics}\n"
@@ -295,7 +332,7 @@ class LLMHandler:
295
  # Get max_new_tokens from model config or use a default
296
  max_new_tokens = getattr(self.llm.config, 'max_new_tokens', 4096)
297
  if hasattr(self, 'max_model_len'):
298
- max_new_tokens = min(max_new_tokens, self.max_model_len)
299
 
300
  # Define custom streamer for tqdm
301
  class TqdmTokenStreamer(BaseStreamer):
@@ -315,15 +352,78 @@ class LLMHandler:
315
 
316
  streamer = TqdmTokenStreamer(total=max_new_tokens)
317
 
318
- with torch.no_grad():
319
- outputs = self.llm.generate(
320
- **inputs,
321
  max_new_tokens=max_new_tokens,
322
  temperature=temperature,
323
- do_sample=True if temperature > 0 else False,
 
324
  pad_token_id=self.llm_tokenizer.pad_token_id or self.llm_tokenizer.eos_token_id,
325
  streamer=streamer,
326
  )
 
327
 
328
  # Decode the generated tokens
329
  # Only decode the newly generated tokens (skip the input prompt)
@@ -338,7 +438,17 @@ class LLMHandler:
338
  error_msg = f"❌ Error generating with 5Hz LM: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
339
  return {}, "", error_msg
340
 
341
- def generate_with_5hz_lm(self, caption: str, lyrics: str, temperature: float = 0.6, cfg_scale: float = 1.0, negative_prompt: str = "NO USER INPUT") -> Tuple[Dict[str, Any], str, str]:
342
  """Generate metadata and audio codes using 5Hz LM"""
343
  # Check if 5Hz LM is initialized
344
  if not hasattr(self, 'llm_initialized') or not self.llm_initialized:
@@ -355,9 +465,15 @@ class LLMHandler:
355
  return {}, "", "❌ 5Hz LM backend not set. Please initialize it first."
356
 
357
  if self.llm_backend == "vllm":
358
- return self.generate_with_5hz_lm_vllm(caption, lyrics, temperature, cfg_scale, negative_prompt)
 
 
 
359
  else:
360
- return self.generate_with_5hz_lm_pt(caption, lyrics, temperature)
 
 
 
361
 
362
  def parse_lm_output(self, output_text: str) -> Tuple[Dict[str, Any], str]:
363
  """
@@ -440,6 +556,112 @@ class LLMHandler:
440
 
441
  return metadata, audio_codes
442
 
443
  @contextmanager
444
  def _load_model_context(self):
445
  """
 
11
  import torch
12
  from tqdm import tqdm
13
  from loguru import logger
14
+ from transformers import AutoTokenizer, AutoModelForCausalLM, ClassifierFreeGuidanceLogitsProcessor
15
  from transformers.generation.streamers import BaseStreamer
16
+ from transformers.generation.logits_process import (
17
+ LogitsProcessorList,
18
+ LogitsProcessor,
19
+ TopKLogitsWarper,
20
+ TopPLogitsWarper,
21
+ RepetitionPenaltyLogitsProcessor,
22
+ TemperatureLogitsWarper,
23
+ )
24
+
25
+
26
 
27
 
28
  class LLMHandler:
 
219
  error_msg = f"❌ Error initializing 5Hz LM: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
220
  return error_msg
221
 
222
+ def generate_with_5hz_lm_vllm(
223
+ self,
224
+ caption: str,
225
+ lyrics: str,
226
+ temperature: float = 0.6,
227
+ cfg_scale: float = 1.0,
228
+ negative_prompt: str = "NO USER INPUT",
229
+ top_k: Optional[int] = None,
230
+ top_p: Optional[float] = None,
231
+ repetition_penalty: float = 1.0,
232
+ ) -> Tuple[Dict[str, Any], str, str]:
233
  """Generate metadata and audio codes using 5Hz LM with vllm backend"""
234
  try:
235
  from nanovllm import SamplingParams
 
246
  )
247
  logger.debug(f"[debug] formatted_prompt: {formatted_prompt}")
248
 
249
+ sampling_params = SamplingParams(
250
+ max_tokens=self.max_model_len-64,
251
+ temperature=temperature,
252
+ cfg_scale=cfg_scale,
253
+ top_k=top_k,
254
+ top_p=top_p,
255
+ repetition_penalty=repetition_penalty,
256
+ )
257
  # Use CFG if cfg_scale > 1.0
258
  if cfg_scale > 1.0:
259
  # Build unconditional prompt (user input replaced with "NO USER INPUT")
 
293
  error_msg = f"❌ Error generating with 5Hz LM: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
294
  return {}, "", error_msg
295
 
296
+ def generate_with_5hz_lm_pt(
297
+ self,
298
+ caption: str,
299
+ lyrics: str,
300
+ temperature: float = 0.6,
301
+ cfg_scale: float = 1.0,
302
+ negative_prompt: str = "NO USER INPUT",
303
+ top_k: Optional[int] = None,
304
+ top_p: Optional[float] = None,
305
+ repetition_penalty: float = 1.0,
306
+ ) -> Tuple[Dict[str, Any], str, str]:
307
  """Generate metadata and audio codes using 5Hz LM with PyTorch backend"""
308
  try:
309
  prompt = f"# Caption\n{caption}\n\n# Lyric\n{lyrics}\n"
 
332
  # Get max_new_tokens from model config or use a default
333
  max_new_tokens = getattr(self.llm.config, 'max_new_tokens', 4096)
334
  if hasattr(self, 'max_model_len'):
335
+ max_new_tokens = min(max_new_tokens, self.max_model_len - 64)
336
 
337
  # Define custom streamer for tqdm
338
  class TqdmTokenStreamer(BaseStreamer):
 
352
 
353
  streamer = TqdmTokenStreamer(total=max_new_tokens)
354
 
355
+ # Build logits processor list
356
+ logits_processor = LogitsProcessorList()
357
+
358
+ # Add repetition penalty if needed
359
+ if repetition_penalty != 1.0:
360
+ logits_processor.append(RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty))
361
+
362
+ # Add temperature warper if needed (temperature is handled separately in generate, but we can also use warper)
363
+ # Note: temperature is passed directly to generate(), but we can use TemperatureLogitsWarper for consistency
364
+ if temperature != 1.0:
365
+ logits_processor.append(TemperatureLogitsWarper(temperature=temperature))
366
+
367
+ # Add top-k warper if specified
368
+ if top_k is not None and top_k > 0:
369
+ logits_processor.append(TopKLogitsWarper(top_k=top_k))
370
+
371
+ # Add top-p warper if specified
372
+ if top_p is not None and top_p > 0.0 and top_p < 1.0:
373
+ logits_processor.append(TopPLogitsWarper(top_p=top_p))
374
+
375
+ # Handle CFG if cfg_scale > 1.0
376
+ if cfg_scale > 1.0:
377
+ # Build unconditional prompt
378
+ formatted_unconditional_prompt = self.llm_tokenizer.apply_chat_template(
379
+ [
380
+ {"role": "system", "content": "# Instruction\nGenerate audio semantic tokens based on the given conditions:\n\n"},
381
+ {"role": "user", "content": negative_prompt}
382
+ ],
383
+ tokenize=False,
384
+ add_generation_prompt=True,
385
+ )
386
+
387
+ # Tokenize unconditional prompt
388
+ uncond_inputs = self.llm_tokenizer(
389
+ formatted_unconditional_prompt,
390
+ return_tensors="pt",
391
+ padding=False,
392
+ truncation=True,
393
+ )
394
+ uncond_inputs = {k: v.to(self.device) for k, v in uncond_inputs.items()}
395
+
396
+ # Use custom CFG generation with batch processing
397
+ # Combine conditional and unconditional inputs into a batch
398
+ # Format: [cond_input, uncond_input]
399
+ batch_input_ids = torch.cat([inputs['input_ids'], uncond_inputs['input_ids']], dim=0)
400
+ batch_attention_mask = None
401
+ if 'attention_mask' in inputs:
402
+ batch_attention_mask = torch.cat([inputs['attention_mask'], uncond_inputs.get('attention_mask', torch.ones_like(uncond_inputs['input_ids']))], dim=0)
403
+
404
+ # Custom CFG generation loop
405
+ outputs = self._generate_with_cfg(
406
+ batch_input_ids=batch_input_ids,
407
+ batch_attention_mask=batch_attention_mask,
408
  max_new_tokens=max_new_tokens,
409
  temperature=temperature,
410
+ cfg_scale=cfg_scale,
411
+ logits_processor=logits_processor,
412
  pad_token_id=self.llm_tokenizer.pad_token_id or self.llm_tokenizer.eos_token_id,
413
  streamer=streamer,
414
  )
415
+ else:
416
+ # Generate without CFG
417
+ with torch.no_grad():
418
+ outputs = self.llm.generate(
419
+ **inputs,
420
+ max_new_tokens=max_new_tokens,
421
+ temperature=temperature if temperature > 0 else 1.0,
422
+ do_sample=True if temperature > 0 else False,
423
+ logits_processor=logits_processor if len(logits_processor) > 0 else None,
424
+ pad_token_id=self.llm_tokenizer.pad_token_id or self.llm_tokenizer.eos_token_id,
425
+ streamer=streamer,
426
+ )
427
 
428
  # Decode the generated tokens
429
  # Only decode the newly generated tokens (skip the input prompt)
 
438
  error_msg = f"❌ Error generating with 5Hz LM: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
439
  return {}, "", error_msg
440
 
441
+ def generate_with_5hz_lm(
442
+ self,
443
+ caption: str,
444
+ lyrics: str,
445
+ temperature: float = 0.6,
446
+ cfg_scale: float = 1.0,
447
+ negative_prompt: str = "NO USER INPUT",
448
+ top_k: Optional[int] = None,
449
+ top_p: Optional[float] = None,
450
+ repetition_penalty: float = 1.0,
451
+ ) -> Tuple[Dict[str, Any], str, str]:
452
  """Generate metadata and audio codes using 5Hz LM"""
453
  # Check if 5Hz LM is initialized
454
  if not hasattr(self, 'llm_initialized') or not self.llm_initialized:
 
465
  return {}, "", "❌ 5Hz LM backend not set. Please initialize it first."
466
 
467
  if self.llm_backend == "vllm":
468
+ return self.generate_with_5hz_lm_vllm(
469
+ caption, lyrics, temperature, cfg_scale, negative_prompt,
470
+ top_k, top_p, repetition_penalty
471
+ )
472
  else:
473
+ return self.generate_with_5hz_lm_pt(
474
+ caption, lyrics, temperature, cfg_scale, negative_prompt,
475
+ top_k, top_p, repetition_penalty
476
+ )
477
 
478
  def parse_lm_output(self, output_text: str) -> Tuple[Dict[str, Any], str]:
479
  """
 
556
 
557
  return metadata, audio_codes
558
 
559
+ def _generate_with_cfg(
560
+ self,
561
+ batch_input_ids: torch.Tensor,
562
+ batch_attention_mask: Optional[torch.Tensor],
563
+ max_new_tokens: int,
564
+ temperature: float,
565
+ cfg_scale: float,
566
+ logits_processor: Optional[LogitsProcessorList],
567
+ pad_token_id: int,
568
+ streamer: Optional[BaseStreamer],
569
+ ) -> torch.Tensor:
570
+ """
571
+ Custom generation loop with CFG support using batch processing.
572
+ Batch format: [conditional_input, unconditional_input]
573
+ This properly utilizes KV cache by processing both sequences in parallel.
574
+ """
575
+ model = self.llm
576
+ device = self.device
577
+ batch_size = batch_input_ids.shape[0] // 2 # Half are conditional, half are unconditional
578
+ cond_start_idx = 0
579
+ uncond_start_idx = batch_size
580
+
581
+ # Initialize generated sequences
582
+ generated_ids = batch_input_ids.clone()
583
+ if batch_attention_mask is not None:
584
+ attention_mask = batch_attention_mask.clone()
585
+ else:
586
+ attention_mask = torch.ones_like(batch_input_ids)
587
+
588
+ # Prepare model inputs
589
+ model_kwargs = {}
590
+ if batch_attention_mask is not None:
591
+ model_kwargs['attention_mask'] = attention_mask
592
+
593
+ # Past key values for KV cache (if model supports it)
594
+ past_key_values = None
595
+ use_cache = hasattr(model, 'generation_config') and getattr(model.generation_config, 'use_cache', True)
596
+
597
+ with torch.no_grad():
598
+ for step in range(max_new_tokens):
599
+ # Forward pass for the entire batch (conditional + unconditional)
600
+ if past_key_values is None:
601
+ # First step: full forward pass
602
+ outputs = model(
603
+ input_ids=generated_ids,
604
+ **model_kwargs,
605
+ use_cache=use_cache,
606
+ )
607
+ else:
608
+ # Subsequent steps: only forward the last token (utilizing KV cache)
609
+ outputs = model(
610
+ input_ids=generated_ids[:, -1:],
611
+ past_key_values=past_key_values,
612
+ **model_kwargs,
613
+ use_cache=use_cache,
614
+ )
615
+
616
+ # Get logits
617
+ next_token_logits = outputs.logits[:, -1, :] # [batch_size*2, vocab_size]
618
+
619
+ # Split conditional and unconditional logits
620
+ cond_logits = next_token_logits[cond_start_idx:cond_start_idx+batch_size]
621
+ uncond_logits = next_token_logits[uncond_start_idx:uncond_start_idx+batch_size]
622
+
623
+ # Apply CFG formula: logits_cfg = logits_uncond + cfg_scale * (logits_cond - logits_uncond)
624
+ cfg_logits = uncond_logits + cfg_scale * (cond_logits - uncond_logits)
625
+
626
+ # Apply logits processors (temperature, top-k, top-p, repetition penalty)
627
+ if logits_processor is not None:
628
+ # Get current input_ids for repetition penalty (only conditional part)
629
+ current_input_ids = generated_ids[cond_start_idx:cond_start_idx+batch_size]
630
+ for processor in logits_processor:
631
+ cfg_logits = processor(current_input_ids, cfg_logits)
632
+
633
+ # Apply temperature and sample
634
+ if temperature > 0:
635
+ cfg_logits = cfg_logits / temperature
636
+ probs = torch.softmax(cfg_logits, dim=-1)
637
+ next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
638
+ else:
639
+ next_tokens = torch.argmax(cfg_logits, dim=-1)
640
+
641
+ # Update generated sequences (apply same token to both conditional and unconditional)
642
+ next_tokens = next_tokens.unsqueeze(1)
643
+ generated_ids = torch.cat([generated_ids, next_tokens.repeat(2, 1)], dim=1)
644
+ attention_mask = torch.cat([attention_mask, torch.ones((batch_size*2, 1), device=device, dtype=attention_mask.dtype)], dim=1)
645
+ model_kwargs['attention_mask'] = attention_mask
646
+
647
+ # Update past_key_values for next iteration
648
+ if use_cache and hasattr(outputs, 'past_key_values'):
649
+ past_key_values = outputs.past_key_values
650
+
651
+ # Update streamer
652
+ if streamer is not None:
653
+ streamer.put(next_tokens[0]) # Only stream conditional tokens
654
+
655
+ # Check for EOS (simplified - you may want to check model's eos_token_id)
656
+ if (next_tokens[0] == pad_token_id).all():
657
+ break
658
+
659
+ if streamer is not None:
660
+ streamer.end()
661
+
662
+ # Return only conditional output
663
+ return generated_ids[cond_start_idx:cond_start_idx+batch_size]
664
+
665
  @contextmanager
666
  def _load_model_context(self):
667
  """
acestep/third_parts/nano-vllm/nanovllm/engine/model_runner.py CHANGED
@@ -212,22 +212,37 @@ class ModelRunner:
212
  """Prepare sampling parameters. For CFG batch, only return parameters for conditional sequences."""
213
  if is_cfg_batch:
214
  # For CFG batch, seqs contains [cond_seq1, cond_seq2, ..., uncond_seq1, uncond_seq2, ...]
215
- # We only need temperatures for conditional sequences (first half)
216
  num_cond = len(seqs) // 2
217
  temperatures = []
218
  cfg_scales = []
 
 
 
219
  for seq in seqs[:num_cond]:
220
  temperatures.append(seq.temperature)
221
  cfg_scales.append(seq.cfg_scale)
 
 
 
222
  else:
223
  temperatures = []
224
  cfg_scales = []
 
 
 
225
  for seq in seqs:
226
  temperatures.append(seq.temperature)
227
  cfg_scales.append(seq.cfg_scale)
 
 
 
228
  temperatures = torch.tensor(temperatures, dtype=torch.float32, pin_memory=True).cuda(non_blocking=True)
229
  cfg_scales = torch.tensor(cfg_scales, dtype=torch.float32, pin_memory=True).cuda(non_blocking=True)
230
- return temperatures, cfg_scales
 
 
 
231
 
232
  @torch.inference_mode()
233
  def run_model(self, input_ids: torch.Tensor, positions: torch.Tensor, is_prefill: bool):
@@ -274,7 +289,11 @@ class ModelRunner:
274
  # Prepare inputs for both conditional and unconditional (they're already in the batch)
275
  input_ids, positions = (self.prepare_prefill(seqs) if is_prefill
276
  else self.prepare_decode(seqs))
277
- temperatures, cfg_scales = self.prepare_sample(seqs, is_cfg_batch=True) if self.rank == 0 else (None, None)
 
 
 
 
278
 
279
  # Run model forward (processes entire batch: cond + uncond)
280
  logits_all = self.run_model(input_ids, positions, is_prefill)
@@ -285,12 +304,44 @@ class ModelRunner:
285
  logits_cond = logits_all[:num_cond]
286
  logits_uncond = logits_all[num_cond:]
287
 
288
- # Apply CFG formula: logits_cfg = logits_cond + cfg_scale * (logits_cond - logits_uncond)
289
  cfg_scales_tensor = cfg_scales.unsqueeze(1) # [num_cond, 1]
290
- logits_cfg = logits_cond + cfg_scales_tensor * (logits_cond - logits_uncond)
 
 
 
291
 
292
  # Sample from CFG logits
293
- token_ids_cfg = self.sampler(logits_cfg, temperatures).tolist()
294
 
295
  # Return token_ids (will be applied to both conditional and unconditional sequences)
296
  return token_ids_cfg
@@ -300,11 +351,51 @@ class ModelRunner:
300
  # Normal batch (non-CFG)
301
  input_ids, positions = (self.prepare_prefill(seqs) if is_prefill
302
  else self.prepare_decode(seqs))
303
- temperatures, cfg_scales = self.prepare_sample(seqs, is_cfg_batch=False) if self.rank == 0 else (None, None)
 
 
 
 
304
  logits = self.run_model(input_ids, positions, is_prefill)
305
  reset_context()
306
- token_ids = self.sampler(logits, temperatures).tolist() if self.rank == 0 else None
307
- return token_ids
308
 
309
  @torch.inference_mode()
310
  def capture_cudagraph(self):
 
212
  """Prepare sampling parameters. For CFG batch, only return parameters for conditional sequences."""
213
  if is_cfg_batch:
214
  # For CFG batch, seqs contains [cond_seq1, cond_seq2, ..., uncond_seq1, uncond_seq2, ...]
215
+ # We only need parameters for conditional sequences (first half)
216
  num_cond = len(seqs) // 2
217
  temperatures = []
218
  cfg_scales = []
219
+ top_ks = []
220
+ top_ps = []
221
+ repetition_penalties = []
222
  for seq in seqs[:num_cond]:
223
  temperatures.append(seq.temperature)
224
  cfg_scales.append(seq.cfg_scale)
225
+ top_ks.append(seq.top_k if seq.top_k is not None else 0)
226
+ top_ps.append(seq.top_p if seq.top_p is not None else 1.0)
227
+ repetition_penalties.append(seq.repetition_penalty)
228
  else:
229
  temperatures = []
230
  cfg_scales = []
231
+ top_ks = []
232
+ top_ps = []
233
+ repetition_penalties = []
234
  for seq in seqs:
235
  temperatures.append(seq.temperature)
236
  cfg_scales.append(seq.cfg_scale)
237
+ top_ks.append(seq.top_k if seq.top_k is not None else 0)
238
+ top_ps.append(seq.top_p if seq.top_p is not None else 1.0)
239
+ repetition_penalties.append(seq.repetition_penalty)
240
  temperatures = torch.tensor(temperatures, dtype=torch.float32, pin_memory=True).cuda(non_blocking=True)
241
  cfg_scales = torch.tensor(cfg_scales, dtype=torch.float32, pin_memory=True).cuda(non_blocking=True)
242
+ top_ks = torch.tensor(top_ks, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
243
+ top_ps = torch.tensor(top_ps, dtype=torch.float32, pin_memory=True).cuda(non_blocking=True)
244
+ repetition_penalties = torch.tensor(repetition_penalties, dtype=torch.float32, pin_memory=True).cuda(non_blocking=True)
245
+ return temperatures, cfg_scales, top_ks, top_ps, repetition_penalties
246
 
247
  @torch.inference_mode()
248
  def run_model(self, input_ids: torch.Tensor, positions: torch.Tensor, is_prefill: bool):
 
289
  # Prepare inputs for both conditional and unconditional (they're already in the batch)
290
  input_ids, positions = (self.prepare_prefill(seqs) if is_prefill
291
  else self.prepare_decode(seqs))
292
+ sample_params = self.prepare_sample(seqs, is_cfg_batch=True) if self.rank == 0 else None
293
+ if sample_params is not None:
294
+ temperatures, cfg_scales, top_ks, top_ps, repetition_penalties = sample_params
295
+ else:
296
+ temperatures = cfg_scales = top_ks = top_ps = repetition_penalties = None
297
 
298
  # Run model forward (processes entire batch: cond + uncond)
299
  logits_all = self.run_model(input_ids, positions, is_prefill)
 
304
  logits_cond = logits_all[:num_cond]
305
  logits_uncond = logits_all[num_cond:]
306
 
307
+ # Apply repetition penalty to conditional logits (before CFG)
308
+ if repetition_penalties is not None:
309
+ for i, seq in enumerate(cond_seqs):
310
+ penalty = repetition_penalties[i].item()
311
+ if penalty != 1.0:
312
+ # Only penalize completion tokens (not prompt tokens)
313
+ completion_tokens = torch.tensor(seq.completion_token_ids, device=logits_cond.device)
314
+ if len(completion_tokens) > 0:
315
+ # Create token mask: mark tokens that appeared in completion
316
+ token_mask = torch.zeros(logits_cond.shape[1], dtype=torch.bool, device=logits_cond.device)
317
+ token_mask[completion_tokens] = True
318
+
319
+ # Apply standard repetition penalty formula (matching transformers implementation):
320
+ # For tokens in completion: if score < 0 then score * penalty, else score / penalty
321
+ penalty_scores = torch.where(
322
+ logits_cond[i] < 0,
323
+ logits_cond[i] * penalty,
324
+ logits_cond[i] / penalty
325
+ )
326
+ # Only apply penalty to tokens that appeared in completion
327
+ logits_cond[i] = torch.where(token_mask, penalty_scores, logits_cond[i])
328
+
329
+ # Apply CFG formula: logits_cfg = logits_uncond + cfg_scale * (logits_cond - logits_uncond)
330
  cfg_scales_tensor = cfg_scales.unsqueeze(1) # [num_cond, 1]
331
+ logits_cfg = logits_uncond + cfg_scales_tensor * (logits_cond - logits_uncond)
332
+
333
+ # Prepare input_ids for sampler (for repetition penalty, though we already applied it)
334
+ cond_input_ids = torch.tensor([seq.token_ids for seq in cond_seqs], device=logits_cfg.device)
335
 
336
  # Sample from CFG logits
337
+ token_ids_cfg = self.sampler(
338
+ logits_cfg,
339
+ temperatures,
340
+ top_ks=top_ks if top_ks is not None else None,
341
+ top_ps=top_ps if top_ps is not None else None,
342
+ repetition_penalties=None, # Already applied above
343
+ input_ids=cond_input_ids,
344
+ ).tolist()
345
 
346
  # Return token_ids (will be applied to both conditional and unconditional sequences)
347
  return token_ids_cfg
 
351
  # Normal batch (non-CFG)
352
  input_ids, positions = (self.prepare_prefill(seqs) if is_prefill
353
  else self.prepare_decode(seqs))
354
+ sample_params = self.prepare_sample(seqs, is_cfg_batch=False) if self.rank == 0 else None
355
+ if sample_params is not None:
356
+ temperatures, cfg_scales, top_ks, top_ps, repetition_penalties = sample_params
357
+ else:
358
+ temperatures = cfg_scales = top_ks = top_ps = repetition_penalties = None
359
  logits = self.run_model(input_ids, positions, is_prefill)
360
  reset_context()
361
+
362
+ if self.rank == 0:
363
+ # Apply repetition penalty to logits
364
+ if repetition_penalties is not None:
365
+ for i, seq in enumerate(seqs):
366
+ penalty = repetition_penalties[i].item()
367
+ if penalty != 1.0:
368
+ # Only penalize completion tokens (not prompt tokens)
369
+ completion_tokens = torch.tensor(seq.completion_token_ids, device=logits.device)
370
+ if len(completion_tokens) > 0:
371
+ # Create token mask: mark tokens that appeared in completion
372
+ token_mask = torch.zeros(logits.shape[1], dtype=torch.bool, device=logits.device)
373
+ token_mask[completion_tokens] = True
374
+
375
+ # Apply standard repetition penalty formula (matching transformers implementation):
376
+ # For tokens in completion: if score < 0 then score * penalty, else score / penalty
377
+ penalty_scores = torch.where(
378
+ logits[i] < 0,
379
+ logits[i] * penalty,
380
+ logits[i] / penalty
381
+ )
382
+ # Only apply penalty to tokens that appeared in completion
383
+ logits[i] = torch.where(token_mask, penalty_scores, logits[i])
384
+
385
+ # Prepare input_ids for sampler
386
+ seq_input_ids = torch.tensor([seq.token_ids for seq in seqs], device=logits.device)
387
+
388
+ token_ids = self.sampler(
389
+ logits,
390
+ temperatures,
391
+ top_ks=top_ks if top_ks is not None else None,
392
+ top_ps=top_ps if top_ps is not None else None,
393
+ repetition_penalties=None, # Already applied above
394
+ input_ids=seq_input_ids,
395
+ ).tolist()
396
+ return token_ids
397
+ else:
398
+ return None
399
 
400
  @torch.inference_mode()
401
  def capture_cudagraph(self):
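
Editor's note: the model_runner hunks above change two things in the decode path. A transformers-style repetition penalty is now applied to the conditional logits before guidance, and the guidance mix uses the standard form logits_uncond + cfg_scale * (logits_cond - logits_uncond). The standalone sketch below restates that logic outside the ModelRunner for reference; the helper name cfg_mix_and_penalize and the flat [num_cond, vocab_size] shapes are illustrative assumptions, not part of the committed code.

import torch

def cfg_mix_and_penalize(
    logits_cond: torch.Tensor,           # [num_cond, vocab_size]
    logits_uncond: torch.Tensor,         # [num_cond, vocab_size]
    cfg_scales: torch.Tensor,            # [num_cond]
    completion_token_ids: list,          # per-sequence lists of already generated token ids
    repetition_penalties: torch.Tensor,  # [num_cond]
) -> torch.Tensor:
    # Repetition penalty (transformers convention), restricted to tokens the
    # sequence has already generated: negative scores are multiplied by the
    # penalty, positive scores are divided by it.
    for i, tokens in enumerate(completion_token_ids):
        penalty = float(repetition_penalties[i])
        if penalty != 1.0 and len(tokens) > 0:
            idx = torch.tensor(tokens, device=logits_cond.device)
            mask = torch.zeros(logits_cond.shape[1], dtype=torch.bool, device=logits_cond.device)
            mask[idx] = True
            penalized = torch.where(logits_cond[i] < 0, logits_cond[i] * penalty, logits_cond[i] / penalty)
            logits_cond[i] = torch.where(mask, penalized, logits_cond[i])

    # Classifier-free guidance: move from the unconditional logits toward the
    # conditional ones; cfg_scale == 1.0 reproduces logits_cond exactly.
    return logits_uncond + cfg_scales.unsqueeze(1) * (logits_cond - logits_uncond)

With cfg_scale == 1.0 the mix is exactly the conditional logits, which is consistent with the SamplingParams assertion that cfg_scale must be >= 1.0.
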
acestep/third_parts/nano-vllm/nanovllm/engine/sequence.py CHANGED
@@ -28,6 +28,9 @@ class Sequence:
28
  self.max_tokens = sampling_params.max_tokens
29
  self.ignore_eos = sampling_params.ignore_eos
30
  self.cfg_scale = sampling_params.cfg_scale
 
 
 
31
  # For CFG: mark if this is an unconditional sequence
32
  self.is_unconditional = is_unconditional
33
  # For CFG: reference to the corresponding conditional sequence (if this is unconditional)
 
28
  self.max_tokens = sampling_params.max_tokens
29
  self.ignore_eos = sampling_params.ignore_eos
30
  self.cfg_scale = sampling_params.cfg_scale
31
+ self.top_k = sampling_params.top_k
32
+ self.top_p = sampling_params.top_p
33
+ self.repetition_penalty = sampling_params.repetition_penalty
34
  # For CFG: mark if this is an unconditional sequence
35
  self.is_unconditional = is_unconditional
36
  # For CFG: reference to the corresponding conditional sequence (if this is unconditional)
acestep/third_parts/nano-vllm/nanovllm/layers/sampler.py CHANGED
@@ -1,5 +1,6 @@
1
  import torch
2
  from torch import nn
 
3
 
4
 
5
  class Sampler(nn.Module):
@@ -8,8 +9,66 @@ class Sampler(nn.Module):
8
  super().__init__()
9
 
10
  @torch.compile
11
- def forward(self, logits: torch.Tensor, temperatures: torch.Tensor):
12
  logits = logits.float().div_(temperatures.unsqueeze(dim=1))
13
  probs = torch.softmax(logits, dim=-1)
14
  sample_tokens = probs.div_(torch.empty_like(probs).exponential_(1).clamp_min_(1e-10)).argmax(dim=-1)
15
  return sample_tokens
 
1
  import torch
2
  from torch import nn
3
+ from typing import Optional
4
 
5
 
6
  class Sampler(nn.Module):
 
9
  super().__init__()
10
 
11
  @torch.compile
12
+ def forward(
13
+ self,
14
+ logits: torch.Tensor,
15
+ temperatures: torch.Tensor,
16
+ top_ks: Optional[torch.Tensor] = None,
17
+ top_ps: Optional[torch.Tensor] = None,
18
+ repetition_penalties: Optional[torch.Tensor] = None,
19
+ input_ids: Optional[torch.Tensor] = None,
20
+ ):
21
+ """
22
+ Sample tokens from logits with optional top-k, top-p, and repetition penalty.
23
+
24
+ Args:
25
+ logits: [batch_size, vocab_size] logits tensor
26
+ temperatures: [batch_size] temperature values
27
+ top_ks: Optional [batch_size] top-k values (None or 0 means no top-k filtering)
28
+ top_ps: Optional [batch_size] top-p values (None or 1.0 means no top-p filtering)
29
+ repetition_penalties: Optional [batch_size] repetition penalty values (1.0 means no penalty)
30
+ input_ids: Optional [batch_size, seq_len] input token ids for repetition penalty
31
+ """
32
+ batch_size, vocab_size = logits.shape
33
+
34
+ # Note: Repetition penalty is applied in ModelRunner before calling sampler
35
+ # This allows us to use the full sequence context
36
+
37
+ # Apply temperature
38
  logits = logits.float().div_(temperatures.unsqueeze(dim=1))
39
+
40
+ # Apply top-k filtering if specified
41
+ if top_ks is not None:
42
+ for i in range(batch_size):
43
+ top_k = top_ks[i].item()
44
+ if top_k > 0 and top_k < vocab_size:
45
+ # Get top-k logits, set others to -inf
46
+ top_k_logits, top_k_indices = torch.topk(logits[i], int(top_k), dim=-1)
47
+ filtered_logits = torch.full_like(logits[i], float('-inf'))
48
+ filtered_logits[top_k_indices] = top_k_logits
49
+ logits[i] = filtered_logits
50
+
51
+ # Apply top-p (nucleus) filtering if specified
52
+ if top_ps is not None:
53
+ probs = torch.softmax(logits, dim=-1)
54
+ for i in range(batch_size):
55
+ top_p = top_ps[i].item()
56
+ if 0.0 < top_p < 1.0:
57
+ # Sort probabilities in descending order
58
+ sorted_probs, sorted_indices = torch.sort(probs[i], descending=True)
59
+ # Calculate cumulative probabilities
60
+ cumsum_probs = torch.cumsum(sorted_probs, dim=-1)
61
+ # Find the cutoff point
62
+ cutoff_idx = (cumsum_probs <= top_p).sum().item()
63
+ if cutoff_idx < len(sorted_indices):
64
+ cutoff_idx += 1 # Include one more token to ensure we have at least one
65
+ # Create mask for tokens to keep
66
+ mask = torch.zeros_like(probs[i])
67
+ mask[sorted_indices[:cutoff_idx]] = 1.0
68
+ # Apply mask: set filtered tokens to -inf
69
+ logits[i] = torch.where(mask > 0, logits[i], torch.tensor(float('-inf'), device=logits.device))
70
+
71
+ # Sample using Gumbel-max trick (equivalent to sampling from softmax)
72
  probs = torch.softmax(logits, dim=-1)
73
  sample_tokens = probs.div_(torch.empty_like(probs).exponential_(1).clamp_min_(1e-10)).argmax(dim=-1)
74
  return sample_tokens
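
Editor's note: the rewritten Sampler.forward above chains temperature scaling, optional per-row top-k and top-p filtering, and a Gumbel-max draw. The single-sequence sketch below shows the same pipeline in isolation; filter_and_sample is a hypothetical helper operating on a 1-D logits tensor, not an API of nano-vllm.

import torch

def filter_and_sample(logits: torch.Tensor, temperature: float,
                      top_k: int = 0, top_p: float = 1.0) -> int:
    logits = logits.float() / temperature

    # Top-k: keep only the k largest logits, push everything else to -inf.
    if 0 < top_k < logits.numel():
        kth_value = torch.topk(logits, top_k).values[-1]
        logits = torch.where(logits < kth_value, torch.full_like(logits, float("-inf")), logits)

    # Top-p (nucleus): keep the smallest prefix of the sorted distribution whose
    # cumulative probability reaches top_p, plus the boundary token.
    if 0.0 < top_p < 1.0:
        probs = torch.softmax(logits, dim=-1)
        sorted_probs, sorted_idx = torch.sort(probs, descending=True)
        keep = int((torch.cumsum(sorted_probs, dim=-1) <= top_p).sum()) + 1
        mask = torch.zeros_like(probs, dtype=torch.bool)
        mask[sorted_idx[:keep]] = True
        logits = torch.where(mask, logits, torch.full_like(logits, float("-inf")))

    # Gumbel-max trick: dividing probabilities by i.i.d. Exponential(1) noise and
    # taking the argmax samples from the softmax distribution.
    probs = torch.softmax(logits, dim=-1)
    noise = torch.empty_like(probs).exponential_(1).clamp_min_(1e-10)
    return int((probs / noise).argmax())

Note that the per-row Python loops and .item() calls in the committed forward are likely to cause graph breaks under @torch.compile; a fully batched formulation would avoid that at the cost of more tensor bookkeeping.
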
acestep/third_parts/nano-vllm/nanovllm/sampling_params.py CHANGED
@@ -1,4 +1,5 @@
1
  from dataclasses import dataclass
 
2
 
3
 
4
  @dataclass
@@ -7,7 +8,15 @@ class SamplingParams:
7
  max_tokens: int = 64
8
  ignore_eos: bool = False
9
  cfg_scale: float = 1.0 # CFG guidance scale. When > 1.0, applies classifier-free guidance
 
 
 
10
 
11
  def __post_init__(self):
12
  assert self.temperature > 1e-10, "greedy sampling is not permitted"
13
  assert self.cfg_scale >= 1.0, "cfg_scale must be >= 1.0"
1
  from dataclasses import dataclass
2
+ from typing import Optional
3
 
4
 
5
  @dataclass
 
8
  max_tokens: int = 64
9
  ignore_eos: bool = False
10
  cfg_scale: float = 1.0 # CFG guidance scale. When > 1.0, applies classifier-free guidance
11
+ top_k: Optional[int] = None # Top-k sampling: consider only top k tokens
12
+ top_p: Optional[float] = None # Top-p (nucleus) sampling: consider tokens with cumulative probability <= top_p
13
+ repetition_penalty: float = 1.0 # Repetition penalty: >1.0 reduces repetition, <1.0 increases it
14
 
15
  def __post_init__(self):
16
  assert self.temperature > 1e-10, "greedy sampling is not permitted"
17
  assert self.cfg_scale >= 1.0, "cfg_scale must be >= 1.0"
18
+ if self.top_k is not None:
19
+ assert self.top_k > 0, "top_k must be > 0"
20
+ if self.top_p is not None:
21
+ assert 0.0 < self.top_p <= 1.0, "top_p must be in (0.0, 1.0]"
22
+ assert self.repetition_penalty > 0.0, "repetition_penalty must be > 0.0"
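
Editor's note: taken together, the new fields make sampling behaviour configurable per request. A minimal usage sketch follows, assuming the package-relative import path shown (the dataclass itself is the one defined in the file above); the defaults reproduce the previous behaviour.

from nanovllm.sampling_params import SamplingParams  # import path assumed from the repo layout

# Defaults (top_k=None, top_p=None, repetition_penalty=1.0) keep the old behaviour.
plain = SamplingParams(temperature=1.0, max_tokens=64, cfg_scale=1.0)

# Per-request sampling controls added by this commit.
guided = SamplingParams(
    temperature=0.85,
    max_tokens=256,
    cfg_scale=2.0,            # classifier-free guidance strength (must be >= 1.0)
    top_k=50,                 # keep only the 50 most likely tokens
    top_p=0.9,                # nucleus sampling over 90% cumulative probability
    repetition_penalty=1.1,   # > 1.0 discourages repeating generated tokens
)
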