ChuxiJ committed
Commit a3b47b7 · 1 Parent(s): 2745dd3
acestep/gradio_ui.py CHANGED
@@ -160,17 +160,18 @@ def create_generation_section(handler) -> dict:
 
     # Service Configuration
     with gr.Accordion("🔧 Service Configuration", open=True) as service_config_accordion:
-        with gr.Row():
-            with gr.Column(scale=2):
+        # Dropdown options section - all dropdowns grouped together
+        with gr.Row(equal_height=True):
+            with gr.Column(scale=4):
                 checkpoint_dropdown = gr.Dropdown(
                     label="Checkpoint File",
                     choices=handler.get_available_checkpoints(),
                     value=None,
                     info="Select a trained model checkpoint file (full path or filename)"
                 )
-            with gr.Column(scale=1):
+            with gr.Column(scale=1, min_width=90):
                 refresh_btn = gr.Button("🔄 Refresh", size="sm")
-
+
         with gr.Row():
             # Get available acestep-v15- model list
             available_models = handler.get_available_acestep_v15_models()
@@ -200,13 +201,20 @@ def create_generation_section(handler) -> dict:
                     value=default_lm_model,
                     info="Select the 5Hz LM model checkpoint (auto-scanned from checkpoints)"
                 )
+                backend_dropdown = gr.Dropdown(
+                    choices=["vllm", "pt"],
+                    value="vllm",
+                    label="5Hz LM Backend",
+                    info="Select backend for 5Hz LM: vllm (faster) or pt (PyTorch, more compatible)"
+                )
+
+            # Checkbox options section - all checkboxes grouped together
+            with gr.Row():
                 init_llm_checkbox = gr.Checkbox(
                     label="Initialize 5Hz LM",
                     value=False,
                     info="Check to initialize 5Hz LM during service initialization",
                 )
-
-            with gr.Row():
                 # Auto-detect flash attention availability
                 flash_attn_available = handler.is_flash_attention_available()
                 use_flash_attention_checkbox = gr.Checkbox(
@@ -223,7 +231,7 @@ def create_generation_section(handler) -> dict:
                 offload_dit_to_cpu_checkbox = gr.Checkbox(
                     label="Offload DiT to CPU",
                     value=False,
-                    info="Offload DiT model to CPU when not in use (only effective if Offload to CPU is checked)"
+                    info="Offload DiT to CPU (needs Offload to CPU)"
                 )
 
             init_btn = gr.Button("Initialize Service", variant="primary", size="lg")
@@ -319,10 +327,29 @@ def create_generation_section(handler) -> dict:
                     maximum=2.0,
                     value=0.7,
                     step=0.1,
-                    scale=2,
-                    info="Temperature for 5Hz LM sampling"
+                    scale=1,
+                    info="Temperature for 5Hz LM sampling (higher = more random, lower = more deterministic)"
+                )
+                lm_cfg_scale = gr.Slider(
+                    label="CFG Scale",
+                    minimum=1.0,
+                    maximum=3.0,
+                    value=1.0,
+                    step=0.1,
+                    scale=1,
+                    info="Classifier-Free Guidance scale for 5Hz LM (1.0 = no CFG, higher = stronger guidance)"
                 )
 
+                # Negative prompt for CFG (only visible when LM initialized and cfg_scale > 1)
+                lm_negative_prompt = gr.Textbox(
+                    label="Negative Prompt",
+                    value="NO USER INPUT",
+                    placeholder="Enter negative prompt for CFG (default: NO USER INPUT)",
+                    visible=False,
+                    info="Negative prompt used for Classifier-Free Guidance when CFG Scale > 1.0",
+                    lines=2
+                )
+
                 # Repainting controls
                 with gr.Group(visible=False) as repainting_group:
                     gr.HTML("<h5>🎨 Repainting Controls (seconds) </h5>")
@@ -495,6 +522,7 @@ def create_generation_section(handler) -> dict:
         "init_status": init_status,
         "lm_model_path": lm_model_path,
         "init_llm_checkbox": init_llm_checkbox,
+        "backend_dropdown": backend_dropdown,
        "use_flash_attention_checkbox": use_flash_attention_checkbox,
         "offload_to_cpu_checkbox": offload_to_cpu_checkbox,
         "offload_dit_to_cpu_checkbox": offload_dit_to_cpu_checkbox,
@@ -510,6 +538,8 @@ def create_generation_section(handler) -> dict:
         "use_5hz_lm_row": use_5hz_lm_row,
         "use_5hz_lm_btn": use_5hz_lm_btn,
         "lm_temperature": lm_temperature,
+        "lm_cfg_scale": lm_cfg_scale,
+        "lm_negative_prompt": lm_negative_prompt,
         "repainting_group": repainting_group,
         "repainting_start": repainting_start,
         "repainting_end": repainting_end,
@@ -666,11 +696,12 @@ def setup_event_handlers(demo, handler, dataset_section, generation_section, res
     )
 
     # Service initialization
-    def init_service_wrapper(checkpoint, config_path, device, init_llm, lm_model_path, use_flash_attention, offload_to_cpu, offload_dit_to_cpu):
+    def init_service_wrapper(checkpoint, config_path, device, init_llm, lm_model_path, backend, use_flash_attention, offload_to_cpu, offload_dit_to_cpu):
        """Wrapper for service initialization, returns status and button state"""
        status, enable = handler.initialize_service(
            checkpoint, config_path, device, init_llm, lm_model_path,
-            use_flash_attention, compile_model=False,
+            backend=backend,
+            use_flash_attention=use_flash_attention, compile_model=False,
            offload_to_cpu=offload_to_cpu, offload_dit_to_cpu=offload_dit_to_cpu
        )
        return status, gr.update(interactive=enable)
@@ -683,6 +714,7 @@ def setup_event_handlers(demo, handler, dataset_section, generation_section, res
            generation_section["device"],
            generation_section["init_llm_checkbox"],
            generation_section["lm_model_path"],
+            generation_section["backend_dropdown"],
            generation_section["use_flash_attention_checkbox"],
            generation_section["offload_to_cpu_checkbox"],
            generation_section["offload_dit_to_cpu_checkbox"],
@@ -690,6 +722,30 @@ def setup_event_handlers(demo, handler, dataset_section, generation_section, res
        outputs=[generation_section["init_status"], generation_section["generate_btn"]]
    )
 
+    # Update negative prompt visibility based on LM initialization and CFG scale
+    def update_negative_prompt_visibility(init_status, cfg_scale):
+        """Update negative prompt visibility: show only if LM initialized and cfg_scale > 1"""
+        # Check if LM is initialized by looking for "5Hz LM backend:" in status
+        lm_initialized = init_status is not None and "5Hz LM backend:" in str(init_status)
+        # Check if cfg_scale > 1
+        cfg_enabled = cfg_scale is not None and float(cfg_scale) > 1.0
+        # Show only if both conditions are met
+        return gr.update(visible=lm_initialized and cfg_enabled)
+
+    # Update visibility when init_status changes
+    generation_section["init_status"].change(
+        fn=update_negative_prompt_visibility,
+        inputs=[generation_section["init_status"], generation_section["lm_cfg_scale"]],
+        outputs=[generation_section["lm_negative_prompt"]]
+    )
+
+    # Update visibility when cfg_scale changes
+    generation_section["lm_cfg_scale"].change(
+        fn=update_negative_prompt_visibility,
+        inputs=[generation_section["init_status"], generation_section["lm_cfg_scale"]],
+        outputs=[generation_section["lm_negative_prompt"]]
+    )
+
     # Generation with progress bar
     def generate_with_progress(
         captions, lyrics, bpm, key_scale, time_signature, vocal_language,
@@ -762,9 +818,9 @@ def setup_event_handlers(demo, handler, dataset_section, generation_section, res
    )
 
    # 5Hz LM generation (simplified version, can be extended as needed)
-    def generate_lm_hints_wrapper(caption, lyrics, temperature):
+    def generate_lm_hints_wrapper(caption, lyrics, temperature, cfg_scale, negative_prompt):
        """Wrapper for 5Hz LM generation"""
-        metadata, audio_codes, status = handler.generate_with_5hz_lm(caption, lyrics, temperature)
+        metadata, audio_codes, status = handler.generate_with_5hz_lm(caption, lyrics, temperature, cfg_scale, negative_prompt)
 
        # Extract metadata values and map to UI fields
        # Handle bpm
@@ -801,7 +857,9 @@ def setup_event_handlers(demo, handler, dataset_section, generation_section, res
        inputs=[
            generation_section["captions"],
            generation_section["lyrics"],
-            generation_section["lm_temperature"]
+            generation_section["lm_temperature"],
+            generation_section["lm_cfg_scale"],
+            generation_section["lm_negative_prompt"]
        ],
        outputs=[
            generation_section["text2music_audio_code_string"],
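The new CFG controls above reduce to a single visibility rule: the Negative Prompt textbox is shown only when the init status reports a 5Hz LM backend and the CFG Scale slider is above 1.0. Below is a minimal sketch of that predicate restated outside Gradio so it can be checked in isolation; the function name is illustrative and not part of the commit.

def negative_prompt_visible(init_status, cfg_scale) -> bool:
    # Mirrors update_negative_prompt_visibility: both conditions must hold.
    lm_initialized = init_status is not None and "5Hz LM backend:" in str(init_status)
    cfg_enabled = cfg_scale is not None and float(cfg_scale) > 1.0
    return lm_initialized and cfg_enabled

# CFG left at the default 1.0 keeps the textbox hidden even after initialization.
assert negative_prompt_visible("5Hz LM backend: vllm", 1.0) is False
assert negative_prompt_visible("5Hz LM backend: vllm", 1.5) is True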
acestep/handler.py CHANGED
@@ -151,6 +151,7 @@ class AceStepHandler:
         device: str = "auto",
         init_llm: bool = False,
         lm_model_path: str = "acestep-5Hz-lm-0.6B",
+        backend: str = "vllm",
         use_flash_attention: bool = False,
         compile_model: bool = False,
         offload_to_cpu: bool = False,
@@ -165,6 +166,7 @@ class AceStepHandler:
             device: Device type
             init_llm: Whether to initialize 5Hz LM model
             lm_model_path: 5Hz LM model path
+            backend: Backend for 5Hz LM model ("vllm" or "pt")
             use_flash_attention: Whether to use flash attention (requires flash_attn package)
             compile_model: Whether to use torch.compile to optimize the model
             offload_to_cpu: Whether to offload models to CPU when not in use
@@ -285,20 +287,20 @@ class AceStepHandler:
         if os.path.exists(full_lm_model_path):
             logger.info("loading 5Hz LM tokenizer...")
             start_time = time.time()
-            llm_tokenizer = deepcopy(self.text_tokenizer)
-            max_audio_length = 2**16 - 1
-            semantic_tokens = [f"<|audio_code_{i}|>" for i in range(max_audio_length)]
-            # 217204
-            llm_tokenizer.add_special_tokens({"additional_special_tokens": semantic_tokens})
+            llm_tokenizer = AutoTokenizer.from_pretrained(full_lm_model_path, use_fast=True)
             logger.info(f"5Hz LM tokenizer loaded successfully in {time.time() - start_time:.2f} seconds")
             self.llm_tokenizer = llm_tokenizer
-            if device == "cuda":
+
+            # Initialize based on user-selected backend
+            if backend == "vllm":
+                # Try to initialize with vllm
                 status_msg = self._initialize_5hz_lm_vllm(full_lm_model_path)
                 logger.info(f"5Hz LM status message: {status_msg}")
                 # Check if initialization failed (status_msg starts with ❌)
                 if status_msg.startswith("❌"):
                     # vllm initialization failed, fallback to PyTorch
                     if not self.llm_initialized:
+                        logger.warning("vllm initialization failed, falling back to PyTorch backend")
                         try:
                             self.llm = AutoModelForCausalLM.from_pretrained(full_lm_model_path, trust_remote_code=True)
                             if not self.offload_to_cpu:
@@ -308,15 +310,14 @@ class AceStepHandler:
                             self.llm.eval()
                             self.llm_backend = "pt"
                             self.llm_initialized = True
-                            logger.info("5Hz LM initialized successfully on CUDA device using Transformers backend")
+                            logger.info("5Hz LM initialized successfully using PyTorch backend (fallback)")
                         except Exception as e:
                             return f"❌ Error initializing 5Hz LM: {str(e)}\n\nTraceback:\n{traceback.format_exc()}", False
                 # If vllm initialization succeeded, self.llm_initialized should already be True
             else:
-                # For CPU or other devices, use PyTorch backend
+                # Use PyTorch backend (pt)
                 try:
                     self.llm = AutoModelForCausalLM.from_pretrained(full_lm_model_path, trust_remote_code=True)
-                    self.llm_tokenizer = AutoTokenizer.from_pretrained(full_lm_model_path, use_fast=True, trust_remote_code=True)
                     if not self.offload_to_cpu:
                         self.llm = self.llm.to(device).to(self.dtype)
                     else:
@@ -324,7 +325,7 @@ class AceStepHandler:
                     self.llm.eval()
                     self.llm_backend = "pt"
                     self.llm_initialized = True
-                    logger.info("5Hz LM initialized successfully on non-CUDA device using Transformers backend")
+                    logger.info(f"5Hz LM initialized successfully using PyTorch backend on {device}")
                 except Exception as e:
                     return f"❌ Error initializing 5Hz LM: {str(e)}\n\nTraceback:\n{traceback.format_exc()}", False
 
@@ -340,7 +341,9 @@ class AceStepHandler:
         status_msg += f"VAE: {vae_checkpoint_path}\n"
         status_msg += f"Text encoder: {text_encoder_path}\n"
         if init_llm and hasattr(self, 'llm') and self.llm is not None:
+            backend_info = getattr(self, 'llm_backend', 'unknown')
            status_msg += f"5Hz LM model: {os.path.join(checkpoint_dir, lm_model_path)}\n"
+            status_msg += f"5Hz LM backend: {backend_info}\n"
        else:
            status_msg += f"5Hz LM model: Not loaded (checkbox not selected)\n"
        status_msg += f"Dtype: {self.dtype}\n"
@@ -494,9 +497,9 @@ class AceStepHandler:
                max_ratio=0.9
            )
            if low_gpu_memory_mode:
-                self.max_model_len = 1024
-            else:
                self.max_model_len = 2048
+            else:
+                self.max_model_len = 4096
 
            logger.info(f"Initializing 5Hz LM with model: {model_path}, enforce_eager: False, tensor_parallel_size: 1, max_model_len: {self.max_model_len}, gpu_memory_utilization: {gpu_memory_utilization}")
            start_time = time.time()
@@ -506,9 +509,9 @@ class AceStepHandler:
                tensor_parallel_size=1,
                max_model_len=self.max_model_len,
                gpu_memory_utilization=gpu_memory_utilization,
+                tokenizer=self.llm_tokenizer,
            )
            logger.info(f"5Hz LM initialized successfully in {time.time() - start_time:.2f} seconds")
-            self.llm.tokenizer = self.llm_tokenizer
            self.llm_initialized = True
            self.llm_backend = "vllm"
            return f"✅ 5Hz LM initialized successfully\nModel: {model_path}\nDevice: {device_name}\nGPU Memory Utilization: {gpu_memory_utilization:.2f}"
@@ -518,7 +521,7 @@ class AceStepHandler:
            error_msg = f"❌ Error initializing 5Hz LM: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
            return error_msg
 
-    def generate_with_5hz_lm_vllm(self, caption: str, lyrics: str, temperature: float = 0.6) -> Tuple[Dict[str, Any], str, str]:
+    def generate_with_5hz_lm_vllm(self, caption: str, lyrics: str, temperature: float = 0.6, cfg_scale: float = 1.0, negative_prompt: str = "NO USER INPUT") -> Tuple[Dict[str, Any], str, str]:
        try:
            from nanovllm import SamplingParams
 
@@ -534,35 +537,41 @@ class AceStepHandler:
            )
            logger.debug(f"[debug] formatted_prompt: {formatted_prompt}")
 
-            sampling_params = SamplingParams(max_tokens=self.max_model_len, temperature=temperature)
-            outputs = self.llm.generate([formatted_prompt], sampling_params)
+            sampling_params = SamplingParams(max_tokens=self.max_model_len-64, temperature=temperature, cfg_scale=cfg_scale)
+            # Use CFG if cfg_scale > 1.0
+            if cfg_scale > 1.0:
+                # Build unconditional prompt (user input replaced with "NO USER INPUT")
+                formatted_unconditional_prompt = self.lm_tokenizer.apply_chat_template(
+                    [
+                        {"role": "system", "content": "# Instruction\nGenerate audio semantic tokens based on the given conditions:\n\n"},
+                        {"role": "user", "content": negative_prompt}
+                    ],
+                    tokenize=False,
+                    add_generation_prompt=True,
+                )
+                outputs = self.llm.generate(
+                    [formatted_prompt],
+                    sampling_params,
+                    unconditional_prompts=[formatted_unconditional_prompt]
+                )
+            else:
+                outputs = self.lm_model.generate([formatted_prompt], sampling_params)
+            # Extract text from output - handle different output formats
            if isinstance(outputs, list) and len(outputs) > 0:
                if hasattr(outputs[0], 'outputs') and len(outputs[0].outputs) > 0:
                    output_text = outputs[0].outputs[0].text
                elif hasattr(outputs[0], 'text'):
                    output_text = outputs[0].text
+                elif isinstance(outputs[0], dict) and 'text' in outputs[0]:
+                    output_text = outputs[0]['text']
                else:
-                    # Transformers generation
-                    inputs = self.llm_tokenizer(formatted_prompt, return_tensors="pt").to(self.llm.device)
-
-                    # Generate
-                    with torch.no_grad():
-                        outputs = self.llm.generate(
-                            **inputs,
-                            max_new_tokens=3072,
-                            temperature=temperature,
-                            do_sample=True,
-                            pad_token_id=self.llm_tokenizer.pad_token_id,
-                            eos_token_id=self.llm_tokenizer.eos_token_id
-                        )
-
-                    # Decode
-                    generated_ids = outputs[0][inputs.input_ids.shape[1]:]
-                    output_text = self.llm_tokenizer.decode(generated_ids, skip_special_tokens=False)
-
-            metadata, audio_codes = self.parse_lm_output(output_text)
-            codes_count = len(audio_codes.split('<|audio_code_')) - 1 if audio_codes else 0
-            return metadata, audio_codes, f"✅ Generated successfully\nOutput length: {len(output_text)} chars\nCodes count: {codes_count}"
+                    output_text = str(outputs[0])
+            else:
+                output_text = str(outputs)
+            metadata, audio_codes = self.parse_lm_output(output_text)
+            print(f"[debug]output_text: {output_text}")
+            codes_count = len(audio_codes.split('<|audio_code_')) - 1 if audio_codes else 0
+            return metadata, audio_codes, f"✅ Generated successfully\nOutput length: {len(output_text)} chars\nCodes count: {codes_count}"
 
        except Exception as e:
            error_msg = f"❌ Error generating with 5Hz LM: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
@@ -639,7 +648,7 @@ class AceStepHandler:
            error_msg = f"❌ Error generating with 5Hz LM: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
            return {}, "", error_msg
 
-    def generate_with_5hz_lm(self, caption: str, lyrics: str, temperature: float = 0.6) -> Tuple[Dict[str, Any], str, str]:
+    def generate_with_5hz_lm(self, caption: str, lyrics: str, temperature: float = 0.6, cfg_scale: float = 1.0, negative_prompt: str = "NO USER INPUT") -> Tuple[Dict[str, Any], str, str]:
        """Generate metadata and audio codes using 5Hz LM"""
        # Check if 5Hz LM is initialized
        if not hasattr(self, 'llm_initialized') or not self.llm_initialized:
@@ -656,7 +665,7 @@ class AceStepHandler:
            return {}, "", "❌ 5Hz LM backend not set. Please initialize it first."
 
        if self.llm_backend == "vllm":
-            return self.generate_with_5hz_lm_vllm(caption, lyrics, temperature)
+            return self.generate_with_5hz_lm_vllm(caption, lyrics, temperature, cfg_scale, negative_prompt)
        else:
            return self.generate_with_5hz_lm_pt(caption, lyrics, temperature)
 
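For callers of the handler, the visible change is the extended generate_with_5hz_lm signature. A hedged usage sketch follows; the wrapper function, caption, and lyrics are illustrative, and it assumes an AceStepHandler whose 5Hz LM was initialized with the vllm backend, since the pt path still ignores the CFG arguments.

from typing import Any, Dict, Tuple

def lm_hints_with_cfg(handler: Any, caption: str, lyrics: str) -> Tuple[Dict[str, Any], str, str]:
    # cfg_scale > 1.0 switches on the unconditional branch built from negative_prompt;
    # cfg_scale = 1.0 reproduces the old plain-sampling behaviour.
    return handler.generate_with_5hz_lm(
        caption,
        lyrics,
        temperature=0.7,
        cfg_scale=1.5,
        negative_prompt="NO USER INPUT",
    )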
acestep/third_parts/nano-vllm/nanovllm/config.py CHANGED
@@ -1,35 +1,8 @@
 import os
-import socket
 from dataclasses import dataclass
 from transformers import AutoConfig
 
 
-def find_available_port(start_port: int = 2333, max_attempts: int = 100) -> int:
-    """Find an available port starting from start_port.
-
-    Args:
-        start_port: The starting port number to check
-        max_attempts: Maximum number of ports to try
-
-    Returns:
-        An available port number
-
-    Raises:
-        RuntimeError: If no available port is found within max_attempts
-    """
-    for i in range(max_attempts):
-        port = start_port + i
-        try:
-            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
-                s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
-                s.bind(('localhost', port))
-                return port
-        except OSError:
-            # Port is in use, try next one
-            continue
-    raise RuntimeError(f"Could not find an available port starting from {start_port} after {max_attempts} attempts")
-
-
 @dataclass
 class Config:
     model: str
@@ -40,10 +13,9 @@ class Config:
     tensor_parallel_size: int = 1
     enforce_eager: bool = False
     hf_config: AutoConfig | None = None
-    eos: int = 151643
+    eos: int = -1
     kvcache_block_size: int = 256
     num_kvcache_blocks: int = -1
-    dist_port: int | None = None
 
     def __post_init__(self):
         assert os.path.isdir(self.model)
@@ -52,6 +24,3 @@ class Config:
         self.hf_config = AutoConfig.from_pretrained(self.model)
         self.max_model_len = min(self.max_model_len, self.hf_config.max_position_embeddings)
         assert self.max_num_batched_tokens >= self.max_model_len
-        # Auto-find available port if not specified
-        if self.dist_port is None:
-            self.dist_port = find_available_port()
acestep/third_parts/nano-vllm/nanovllm/engine/llm_engine.py CHANGED
@@ -21,28 +21,6 @@ class LLMEngine:
         self.ps = []
         self.events = []
         ctx = mp.get_context("spawn")
-
-        # Pre-validate port availability by attempting to bind to it
-        # This helps avoid race conditions when multiple LLMEngine instances start simultaneously
-        import socket
-        from nanovllm.config import find_available_port
-        max_port_retries = 10
-        for port_attempt in range(max_port_retries):
-            try:
-                # Test if port is actually available by binding to it
-                test_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
-                test_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
-                test_socket.bind(('localhost', config.dist_port))
-                test_socket.close()
-                # Port is available, break
-                break
-            except OSError:
-                # Port is in use, find next available
-                if port_attempt < max_port_retries - 1:
-                    config.dist_port = find_available_port(start_port=config.dist_port + 1, max_attempts=10)
-                else:
-                    raise RuntimeError(f"Failed to find available port after {max_port_retries} attempts")
-
         for i in range(1, config.tensor_parallel_size):
             event = ctx.Event()
             process = ctx.Process(target=ModelRunner, args=(config, i, event))
@@ -50,7 +28,12 @@ class LLMEngine:
             self.ps.append(process)
             self.events.append(event)
         self.model_runner = ModelRunner(config, 0, self.events)
-        self.tokenizer = None
+        tokenizer = kwargs.get("tokenizer", None)
+        if tokenizer is not None:
+            self.tokenizer = tokenizer
+        else:
+            self.tokenizer = AutoTokenizer.from_pretrained(config.model, use_fast=True)
+        config.eos = self.tokenizer.eos_token_id
         self.scheduler = Scheduler(config)
         atexit.register(self.exit)
 
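With this change the engine reuses a caller-supplied tokenizer instead of loading one from disk, and config.eos is taken from tokenizer.eos_token_id rather than the old hard-coded 151643 (see the config.py hunk above). Below is a sketch of the call path, under the assumption that the bundled nanovllm package exposes its usual LLM wrapper and forwards keyword arguments, including tokenizer, through to LLMEngine; the model path is illustrative.

from transformers import AutoTokenizer
from nanovllm import LLM  # assumption: LLM forwards **kwargs to LLMEngine

model_dir = "/path/to/acestep-5Hz-lm-0.6B"   # illustrative path
tokenizer = AutoTokenizer.from_pretrained(model_dir, use_fast=True)

# The pre-built tokenizer (e.g. one already carrying <|audio_code_*|> specials) is reused,
# and the engine derives its EOS id from it rather than from a fixed constant.
llm = LLM(
    model_dir,
    enforce_eager=False,
    tensor_parallel_size=1,
    max_model_len=4096,
    gpu_memory_utilization=0.8,
    tokenizer=tokenizer,
)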
acestep/third_parts/nano-vllm/nanovllm/engine/model_runner.py CHANGED
@@ -1,17 +1,44 @@
 import pickle
-import socket
 import torch
 import torch.distributed as dist
 from multiprocessing.synchronize import Event
 from multiprocessing.shared_memory import SharedMemory
 
-from nanovllm.config import Config, find_available_port
+from nanovllm.config import Config
 from nanovllm.engine.sequence import Sequence
 from nanovllm.models.qwen3 import Qwen3ForCausalLM
 from nanovllm.layers.sampler import Sampler
 from nanovllm.utils.context import set_context, get_context, reset_context
 from nanovllm.utils.loader import load_model
 
+import socket
+
+
+def find_available_port(start_port: int = 2333, max_attempts: int = 100) -> int:
+    """Find an available port starting from start_port.
+
+    Args:
+        start_port: The starting port number to check
+        max_attempts: Maximum number of ports to try
+
+    Returns:
+        An available port number
+
+    Raises:
+        RuntimeError: If no available port is found within max_attempts
+    """
+    for i in range(max_attempts):
+        port = start_port + i
+        try:
+            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+                s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+                s.bind(('localhost', port))
+                return port
+        except OSError:
+            # Port is in use, try next one
+            continue
+    raise RuntimeError(f"Could not find an available port starting from {start_port} after {max_attempts} attempts")
+
 
 class ModelRunner:
 
@@ -23,33 +50,9 @@ class ModelRunner:
         self.world_size = config.tensor_parallel_size
         self.rank = rank
         self.event = event
-
-        # Try to initialize process group with retry logic for port conflicts
-        # Only rank 0 binds to the port, so only rank 0 needs retry logic
-        dist_port = self.config.dist_port
-        max_retries = 10
-        for attempt in range(max_retries):
-            try:
-                dist.init_process_group("nccl", f"tcp://localhost:{dist_port}", world_size=self.world_size, rank=rank)
-                break
-            except RuntimeError as e:
-                if ("EADDRINUSE" in str(e) or "address already in use" in str(e).lower()) and rank == 0:
-                    # Port is in use, try next port (only for rank 0)
-                    if attempt < max_retries - 1:
-                        # Find next available port
-                        dist_port = find_available_port(start_port=dist_port + 1, max_attempts=10)
-                        self.config.dist_port = dist_port
-                        # If we had a previous failed attempt, destroy any partial process group
-                        if dist.is_initialized():
-                            try:
-                                dist.destroy_process_group()
-                            except:
-                                pass
-                    else:
-                        raise RuntimeError(f"Failed to find available port after {max_retries} attempts. Last error: {e}")
-                else:
-                    # Other error or non-rank-0 process, re-raise
-                    raise
+        dist_port = find_available_port()
+        print(f"[debug]dist_port: {dist_port}")
+        dist.init_process_group("nccl", f"tcp://localhost:{dist_port}", world_size=self.world_size, rank=rank)
         torch.cuda.set_device(rank)
         default_dtype = torch.get_default_dtype()
         torch.set_default_dtype(hf_config.torch_dtype)
@@ -144,15 +147,9 @@ class ModelRunner:
             layer_id += 1
 
     def prepare_block_tables(self, seqs: list[Sequence]):
-        max_len = max(len(seq.block_table) for seq in seqs) if seqs else 0
-        if max_len == 0:
-            # Return empty 2D tensor with correct shape
-            return torch.zeros((len(seqs), 0), dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
+        max_len = max(len(seq.block_table) for seq in seqs)
         block_tables = [seq.block_table + [-1] * (max_len - len(seq.block_table)) for seq in seqs]
         block_tables = torch.tensor(block_tables, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
-        # Ensure it's 2D: if only one sequence, shape should be [1, max_len]
-        if block_tables.dim() == 1:
-            block_tables = block_tables.unsqueeze(0)
         return block_tables
 
     def prepare_prefill(self, seqs: list[Sequence]):
@@ -247,29 +244,7 @@ class ModelRunner:
         graph_vars["slot_mapping"][:bs] = context.slot_mapping
         graph_vars["context_lens"].zero_()
         graph_vars["context_lens"][:bs] = context.context_lens
-        # Handle block_tables: ensure it's 2D and size matches
-        if context.block_tables is not None and context.block_tables.numel() > 0:
-            # Ensure block_tables is 2D
-            if context.block_tables.dim() == 1:
-                # Reshape 1D to 2D: [num_blocks] -> [1, num_blocks]
-                block_tables_2d = context.block_tables.unsqueeze(0)
-            else:
-                block_tables_2d = context.block_tables
-
-            # Get dimensions
-            context_bs = block_tables_2d.size(0)
-            context_num_blocks = block_tables_2d.size(1)
-            graph_num_blocks = graph_vars["block_tables"].size(1)
-
-            # Use minimum to avoid size mismatch
-            num_blocks_to_copy = min(context_num_blocks, graph_num_blocks)
-            actual_bs = min(bs, context_bs)
-
-            # Copy block_tables with size matching
-            graph_vars["block_tables"][:actual_bs, :num_blocks_to_copy] = block_tables_2d[:actual_bs, :num_blocks_to_copy]
-            # Fill remaining with -1 if needed
-            if num_blocks_to_copy < graph_num_blocks:
-                graph_vars["block_tables"][:actual_bs, num_blocks_to_copy:] = -1
+        graph_vars["block_tables"][:bs, :context.block_tables.size(1)] = context.block_tables
         graph.replay()
         return self.model.compute_logits(graph_vars["outputs"][:bs])
 
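The port-probing helper now lives next to its only caller: each ModelRunner picks a free localhost port immediately before dist.init_process_group, instead of carrying a dist_port through Config with retry logic. A quick usage sketch, assuming the bundled nano-vllm package (and its torch dependency) is importable:

from nanovllm.engine.model_runner import find_available_port

# Probes ports from 2333 upward and returns the first one that binds.
port = find_available_port(start_port=2333)
print(f"process group would be initialized at tcp://localhost:{port}")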