ChuxiJ committed
Commit 25e1b40 · Parent: 1e1527a

add 5hz llm
acestep/gradio_ui.py CHANGED
@@ -204,7 +204,6 @@ def create_generation_section(handler) -> dict:
             label="Initialize 5Hz LM",
             value=False,
             info="Check to initialize 5Hz LM during service initialization",
-            interactive=False
         )
 
     with gr.Row():
@@ -298,7 +297,7 @@ def create_generation_section(handler) -> dict:
         )
 
         # 5Hz LM
-        with gr.Row(visible=False) as use_5hz_lm_row:
+        with gr.Row(visible=True) as use_5hz_lm_row:
            use_5hz_lm_btn = gr.Button(
                "Generate LM Hints",
                variant="secondary",
@@ -748,9 +747,36 @@ def setup_event_handlers(demo, handler, dataset_section, generation_section, res
     def generate_lm_hints_wrapper(caption, lyrics, temperature):
         """Wrapper for 5Hz LM generation"""
         metadata, audio_codes, status = handler.generate_with_5hz_lm(caption, lyrics, temperature)
-        # Return the formatted result; adjust as needed
-        result_text = f"Status: {status}\n\nMetadata: {metadata}\n\nAudio Codes: {audio_codes[:200]}..." if len(audio_codes) > 200 else f"Status: {status}\n\nMetadata: {metadata}\n\nAudio Codes: {audio_codes}"
-        return result_text
+
+        # Extract metadata values and map to UI fields
+        # Handle bpm
+        bpm_value = metadata.get('bpm', None)
+        if bpm_value == "N/A" or bpm_value == "":
+            bpm_value = None
+
+        # Handle key_scale (metadata uses 'keyscale')
+        key_scale_value = metadata.get('keyscale', metadata.get('key_scale', ""))
+        if key_scale_value == "N/A":
+            key_scale_value = ""
+
+        # Handle time_signature (metadata uses 'timesignature')
+        time_signature_value = metadata.get('timesignature', metadata.get('time_signature', ""))
+        if time_signature_value == "N/A":
+            time_signature_value = ""
+
+        # Handle audio_duration (metadata uses 'duration')
+        audio_duration_value = metadata.get('duration', -1)
+        if audio_duration_value == "N/A" or audio_duration_value == "":
+            audio_duration_value = -1
+
+        # Return audio codes and all metadata fields
+        return (
+            audio_codes,           # text2music_audio_code_string
+            bpm_value,             # bpm
+            key_scale_value,       # key_scale
+            time_signature_value,  # time_signature
+            audio_duration_value,  # audio_duration
+        )
 
     generation_section["use_5hz_lm_btn"].click(
         fn=generate_lm_hints_wrapper,
@@ -759,7 +785,13 @@ def setup_event_handlers(demo, handler, dataset_section, generation_section, res
             generation_section["lyrics"],
             generation_section["lm_temperature"]
         ],
-        outputs=[generation_section["text2music_audio_code_string"]]
+        outputs=[
+            generation_section["text2music_audio_code_string"],
+            generation_section["bpm"],
+            generation_section["key_scale"],
+            generation_section["time_signature"],
+            generation_section["audio_duration"],
+        ]
     )
 
     # Update instruction and UI visibility based on task type
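
Note: Gradio maps the wrapper's return tuple onto the outputs list by position, so the order (audio codes, bpm, key_scale, time_signature, audio_duration) must match on both sides. Below is a minimal, framework-free sketch of the same "N/A" normalization; the helper name normalize_lm_metadata is hypothetical and not part of this commit.

def normalize_lm_metadata(metadata: dict) -> tuple:
    """Map LM metadata onto the neutral defaults the UI widgets expect."""
    bpm = metadata.get("bpm")
    bpm = None if bpm in ("N/A", "") else bpm

    key_scale = metadata.get("keyscale", metadata.get("key_scale", ""))
    key_scale = "" if key_scale == "N/A" else key_scale

    time_signature = metadata.get("timesignature", metadata.get("time_signature", ""))
    time_signature = "" if time_signature == "N/A" else time_signature

    duration = metadata.get("duration", -1)
    duration = -1 if duration in ("N/A", "") else duration

    return bpm, key_scale, time_signature, duration

# Example: a missing or "N/A" bpm becomes None, the rest pass through unchanged.
print(normalize_lm_metadata({"bpm": "N/A", "keyscale": "G major", "timesignature": "4", "duration": 273}))
# -> (None, 'G major', '4', 273)
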
acestep/handler.py CHANGED
@@ -4,6 +4,7 @@ Encapsulates all data processing and business logic as a bridge between model an
 """
 import os
 import math
+from copy import deepcopy
 import tempfile
 import traceback
 import re
@@ -58,9 +59,10 @@ class AceStepHandler:
         self.sample_rate = 48000
 
         # 5Hz LM related
-        self.lm_model = None
-        self.lm_tokenizer = None
-        self.lm_initialized = False
+        self.llm = None
+        self.llm_tokenizer = None
+        self.llm_initialized = False
+        self.llm_backend = None
 
         # Reward model (temporarily disabled)
         self.reward_model = None
@@ -218,12 +220,43 @@ class AceStepHandler:
         if init_llm:
             full_lm_model_path = os.path.join(checkpoint_dir, lm_model_path)
             if os.path.exists(full_lm_model_path):
+                print("loading 5Hz LM tokenizer...")
+                start_time = time.time()
+                llm_tokenizer = deepcopy(self.text_tokenizer)
+                max_audio_length = 2**16 - 1
+                semantic_tokens = [f"<|audio_code_{i}|>" for i in range(max_audio_length)]
+                # 217204
+                llm_tokenizer.add_special_tokens({"additional_special_tokens": semantic_tokens})
+                print(f"5Hz LM tokenizer loaded successfully in {time.time() - start_time:.2f} seconds")
+                self.llm_tokenizer = llm_tokenizer
                 if device == "cuda":
                     status_msg = self._initialize_5hz_lm_cuda(full_lm_model_path)
-                    if not self.llm_initialized:
-                        return status_msg, False
-                self.llm = AutoModel.from_pretrained(full_lm_model_path)
-                self.llm_tokenizer = AutoTokenizer.from_pretrained(full_lm_model_path)
+                    print(f"5Hz LM status message: {status_msg}")
+                    # Check if initialization failed (status_msg starts with ❌)
+                    if status_msg.startswith("❌"):
+                        # vllm initialization failed, fall back to PyTorch
+                        if not self.llm_initialized:
+                            try:
+                                self.llm = AutoModel.from_pretrained(full_lm_model_path)
+                                self.llm = self.llm.to(device).to(self.dtype)
+                                self.llm.eval()
+                                self.llm_backend = "pt"
+                                self.llm_initialized = True
+                            except Exception as e:
+                                return f"❌ Error initializing 5Hz LM: {str(e)}\n\nTraceback:\n{traceback.format_exc()}", False
+                    # If vllm initialization succeeded, self.llm_initialized should already be True
+                else:
+                    # For CPU or other devices, use the PyTorch backend
+                    try:
+                        self.llm = AutoModel.from_pretrained(full_lm_model_path)
+                        self.llm_tokenizer = AutoTokenizer.from_pretrained(full_lm_model_path, use_fast=True)
+                        self.llm = self.llm.to(device).to(self.dtype)
+                        self.llm.eval()
+                        self.llm_backend = "pt"
+                        self.llm_initialized = True
+                    except Exception as e:
+                        return f"❌ Error initializing 5Hz LM: {str(e)}\n\nTraceback:\n{traceback.format_exc()}", False
+
             else:
                 # 5Hz LM path not found
                 return f"❌ 5Hz LM model not found at {full_lm_model_path}", False
@@ -266,6 +299,10 @@ class AceStepHandler:
             reserved_mem_bytes = torch.cuda.memory_reserved(device)
 
             total_gpu = total_gpu_mem_bytes / 1024**3
+            low_gpu_memory_mode = False
+            if total_gpu < minimal_gpu:
+                minimal_gpu = 0.5 * total_gpu
+                low_gpu_memory_mode = True
             allocated_gpu = allocated_mem_bytes / 1024**3
             reserved_gpu = reserved_mem_bytes / 1024**3
             available_gpu = total_gpu - reserved_gpu
@@ -275,54 +312,64 @@ class AceStepHandler:
             else:
                 ratio = min(max_ratio, max(min_ratio, (available_gpu * 0.8) / total_gpu))
 
-            return ratio
+            return ratio, low_gpu_memory_mode
         except Exception as e:
-            return 0.9
+            return 0.9, low_gpu_memory_mode
 
     def _initialize_5hz_lm_cuda(self, model_path: str) -> str:
         """Initialize 5Hz LM model"""
+        if not torch.cuda.is_available():
+            self.llm_initialized = False
+            print("CUDA is not available. Please check your GPU setup.")
+            return "❌ CUDA is not available. Please check your GPU setup."
        try:
            from nanovllm import LLM, SamplingParams
-
-            if not torch.cuda.is_available():
-                return " CUDA is not available. Please check your GPU setup."
-
+        except ImportError:
+            self.llm_initialized = False
+            print("nano-vllm is not installed. Please install it using 'cd acestep/third_parts/nano-vllm && pip install .")
+            return "❌ nano-vllm is not installed. Please install it using 'cd acestep/third_parts/nano-vllm && pip install ."
+
+        try:
            current_device = torch.cuda.current_device()
            device_name = torch.cuda.get_device_name(current_device)
 
            torch.cuda.empty_cache()
-            gpu_memory_utilization = self.get_gpu_memory_utilization(
+            gpu_memory_utilization, low_gpu_memory_mode = self.get_gpu_memory_utilization(
                minimal_gpu=8,
                min_ratio=0.2,
                max_ratio=0.9
            )
+            if low_gpu_memory_mode:
+                self.max_model_len = 1024
+            else:
+                self.max_model_len = 2048
 
+            print(f"Initializing 5Hz LM with model: {model_path}, enforce_eager: False, tensor_parallel_size: 1, max_model_len: {self.max_model_len}, gpu_memory_utilization: {gpu_memory_utilization}")
+            start_time = time.time()
            self.llm = LLM(
                model=model_path,
                enforce_eager=False,
                tensor_parallel_size=1,
-                max_model_len=4096,
+                max_model_len=self.max_model_len,
                gpu_memory_utilization=gpu_memory_utilization,
            )
-            self.llm_tokenizer = self.llm.tokenizer
+            print(f"5Hz LM initialized successfully in {time.time() - start_time:.2f} seconds")
+            self.llm.tokenizer = self.llm_tokenizer
            self.llm_initialized = True
+            self.llm_backend = "vllm"
            return f"✅ 5Hz LM initialized successfully\nModel: {model_path}\nDevice: {device_name}\nGPU Memory Utilization: {gpu_memory_utilization:.2f}"
        except Exception as e:
            self.llm_initialized = False
            error_msg = f"❌ Error initializing 5Hz LM: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
            return error_msg
-
-    def generate_with_5hz_lm(self, caption: str, lyrics: str, temperature: float = 0.6) -> Tuple[Dict[str, Any], str, str]:
-        """Generate metadata and audio codes using 5Hz LM"""
-        if not self.lm_initialized or self.llm is None:
-            return {}, "", "❌ 5Hz LM not initialized. Please initialize it first."
-
+
+    def generate_with_5hz_lm_vllm(self, caption: str, lyrics: str, temperature: float = 0.6) -> Tuple[Dict[str, Any], str, str]:
        try:
            from nanovllm import SamplingParams
 
            prompt = f"# Caption\n{caption}\n\n# Lyric\n{lyrics}\n"
 
-            formatted_prompt = self.lm_tokenizer.apply_chat_template(
+            formatted_prompt = self.llm_tokenizer.apply_chat_template(
                [
                    {"role": "system", "content": "# Instruction\nGenerate audio semantic tokens based on the given conditions:\n\n"},
                    {"role": "user", "content": prompt}
@@ -330,10 +377,10 @@ class AceStepHandler:
                tokenize=False,
                add_generation_prompt=True,
            )
+            print("[debug] formatted_prompt: ", formatted_prompt)
 
-            sampling_params = SamplingParams(max_tokens=3072, temperature=temperature)
+            sampling_params = SamplingParams(max_tokens=self.max_model_len, temperature=temperature)
            outputs = self.llm.generate([formatted_prompt], sampling_params)
-
            if isinstance(outputs, list) and len(outputs) > 0:
                if hasattr(outputs[0], 'outputs') and len(outputs[0].outputs) > 0:
                    output_text = outputs[0].outputs[0].text
@@ -351,22 +398,113 @@ class AceStepHandler:
        except Exception as e:
            error_msg = f"❌ Error generating with 5Hz LM: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
            return {}, "", error_msg
+
+    def generate_with_5hz_lm_pt(self, caption: str, lyrics: str, temperature: float = 0.6) -> Tuple[Dict[str, Any], str, str]:
+        try:
+            prompt = f"# Caption\n{caption}\n\n# Lyric\n{lyrics}\n"
+
+            formatted_prompt = self.llm_tokenizer.apply_chat_template(
+                [
+                    {"role": "system", "content": "# Instruction\nGenerate audio semantic tokens based on the given conditions:\n\n"},
+                    {"role": "user", "content": prompt}
+                ],
+                tokenize=False,
+                add_generation_prompt=True,
+            )
+
+            # Tokenize the prompt
+            inputs = self.llm_tokenizer(
+                formatted_prompt,
+                return_tensors="pt",
+                padding=False,
+                truncation=True,
+            )
+            inputs = {k: v.to(self.device) for k, v in inputs.items()}
+
+            # Generate with the model
+            with torch.no_grad():
+                # Get max_new_tokens from model config or use a default
+                max_new_tokens = getattr(self.llm.config, 'max_new_tokens', 4096)
+                if hasattr(self, 'max_model_len'):
+                    max_new_tokens = min(max_new_tokens, self.max_model_len)
+
+                outputs = self.llm.generate(
+                    **inputs,
+                    max_new_tokens=max_new_tokens,
+                    temperature=temperature,
+                    do_sample=True if temperature > 0 else False,
+                    pad_token_id=self.llm_tokenizer.pad_token_id or self.llm_tokenizer.eos_token_id,
+                )
+
+            # Decode the generated tokens
+            # Only decode the newly generated tokens (skip the input prompt)
+            generated_ids = outputs[0][inputs['input_ids'].shape[1]:]
+            output_text = self.llm_tokenizer.decode(generated_ids, skip_special_tokens=False)
+
+            metadata, audio_codes = self.parse_lm_output(output_text)
+            codes_count = len(audio_codes.split('<|audio_code_')) - 1 if audio_codes else 0
+            return metadata, audio_codes, f"✅ Generated successfully\nOutput length: {len(output_text)} chars\nCodes count: {codes_count}"
+
+        except Exception as e:
+            error_msg = f"❌ Error generating with 5Hz LM: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
+            return {}, "", error_msg
+
+    def generate_with_5hz_lm(self, caption: str, lyrics: str, temperature: float = 0.6) -> Tuple[Dict[str, Any], str, str]:
+        """Generate metadata and audio codes using 5Hz LM"""
+        # Check if 5Hz LM is initialized
+        if not hasattr(self, 'llm_initialized') or not self.llm_initialized:
+            debug_info = f"llm_initialized={getattr(self, 'llm_initialized', 'not set')}, "
+            debug_info += f"has_llm={hasattr(self, 'llm')}, "
+            debug_info += f"llm_is_none={getattr(self, 'llm', None) is None}, "
+            debug_info += f"llm_backend={getattr(self, 'llm_backend', 'not set')}"
+            return {}, "", f"❌ 5Hz LM not initialized. Please initialize it first. Debug: {debug_info}"
+
+        if not hasattr(self, 'llm') or self.llm is None:
+            return {}, "", "❌ 5Hz LM model not loaded. Please initialize it first."
+
+        if not hasattr(self, 'llm_backend'):
+            return {}, "", "❌ 5Hz LM backend not set. Please initialize it first."
+
+        if self.llm_backend == "vllm":
+            return self.generate_with_5hz_lm_vllm(caption, lyrics, temperature)
+        else:
+            return self.generate_with_5hz_lm_pt(caption, lyrics, temperature)
 
     def parse_lm_output(self, output_text: str) -> Tuple[Dict[str, Any], str]:
-        """Parse LM output"""
+        """
+        Parse LM output to extract metadata and audio codes.
+
+        Expected format:
+        <think>
+        bpm: 73
+        duration: 273
+        genres: Chinese folk
+        keyscale: G major
+        timesignature: 4
+        </think>
+
+        <|audio_code_56535|><|audio_code_62918|>...
+
+        Returns:
+            Tuple of (metadata_dict, audio_codes_string)
+        """
+        debug_output_text = output_text.split("</think>")[0]
+        print(f"Debug output text: {debug_output_text}")
        metadata = {}
        audio_codes = ""
 
        import re
 
-        # Extract audio codes
+        # Extract audio codes - find all <|audio_code_XXX|> patterns
        code_pattern = r'<\|audio_code_\d+\|>'
        code_matches = re.findall(code_pattern, output_text)
        if code_matches:
            audio_codes = "".join(code_matches)
 
-        # Extract metadata
+        # Extract metadata from reasoning section
+        # Try different reasoning tag patterns
        reasoning_patterns = [
+            r'<think>(.*?)</think>',
            r'<think>(.*?)</think>',
            r'<reasoning>(.*?)</reasoning>',
        ]
@@ -378,7 +516,9 @@ class AceStepHandler:
                reasoning_text = match.group(1).strip()
                break
 
+        # If no reasoning tags found, try to parse metadata from the beginning of the output
        if not reasoning_text:
+            # Look for metadata lines before the audio codes
            lines_before_codes = output_text.split('<|audio_code_')[0] if '<|audio_code_' in output_text else output_text
            reasoning_text = lines_before_codes.strip()
 
@@ -402,8 +542,12 @@ class AceStepHandler:
                    metadata['duration'] = int(value)
                except:
                    metadata['duration'] = value
-            elif key in ['genres', 'keyscale', 'timesignature']:
-                metadata[key] = value
+            elif key == 'genres':
+                metadata['genres'] = value
+            elif key == 'keyscale':
+                metadata['keyscale'] = value
+            elif key == 'timesignature':
+                metadata['timesignature'] = value
 
        return metadata, audio_codes
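
Note: the new parse_lm_output docstring pins down the LM output contract: metadata as "key: value" lines inside <think>...</think>, followed by a run of <|audio_code_N|> tokens. A self-contained sketch of that parsing on a made-up output string (not real model output):

import re

sample = (
    "<think>\nbpm: 73\nduration: 273\ngenres: Chinese folk\n"
    "keyscale: G major\ntimesignature: 4\n</think>\n\n"
    "<|audio_code_56535|><|audio_code_62918|>"
)

# Audio codes: every <|audio_code_N|> token, concatenated in order.
audio_codes = "".join(re.findall(r"<\|audio_code_\d+\|>", sample))

# Metadata: "key: value" lines inside the <think> block.
metadata = {}
think = re.search(r"<think>(.*?)</think>", sample, re.DOTALL)
if think:
    for line in think.group(1).strip().splitlines():
        if ":" in line:
            key, value = line.split(":", 1)
            metadata[key.strip()] = value.strip()

print(metadata)     # {'bpm': '73', 'duration': '273', 'genres': 'Chinese folk', 'keyscale': 'G major', 'timesignature': '4'}
print(audio_codes)  # <|audio_code_56535|><|audio_code_62918|>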
 
acestep/third_parts/nano-vllm/nanovllm/config.py CHANGED
@@ -1,8 +1,35 @@
 import os
+import socket
 from dataclasses import dataclass
 from transformers import AutoConfig
 
 
+def find_available_port(start_port: int = 2333, max_attempts: int = 100) -> int:
+    """Find an available port starting from start_port.
+
+    Args:
+        start_port: The starting port number to check
+        max_attempts: Maximum number of ports to try
+
+    Returns:
+        An available port number
+
+    Raises:
+        RuntimeError: If no available port is found within max_attempts
+    """
+    for i in range(max_attempts):
+        port = start_port + i
+        try:
+            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+                s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+                s.bind(('localhost', port))
+                return port
+        except OSError:
+            # Port is in use, try next one
+            continue
+    raise RuntimeError(f"Could not find an available port starting from {start_port} after {max_attempts} attempts")
+
+
 @dataclass
 class Config:
     model: str
@@ -13,9 +40,10 @@ class Config:
     tensor_parallel_size: int = 1
     enforce_eager: bool = False
     hf_config: AutoConfig | None = None
-    eos: int = -1
+    eos: int = 151643
     kvcache_block_size: int = 256
     num_kvcache_blocks: int = -1
+    dist_port: int | None = None
 
     def __post_init__(self):
         assert os.path.isdir(self.model)
@@ -24,3 +52,6 @@ class Config:
         self.hf_config = AutoConfig.from_pretrained(self.model)
         self.max_model_len = min(self.max_model_len, self.hf_config.max_position_embeddings)
         assert self.max_num_batched_tokens >= self.max_model_len
+        # Auto-find available port if not specified
+        if self.dist_port is None:
+            self.dist_port = find_available_port()
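
Note: find_available_port scans upward from start_port and returns the first port it can bind. A quick standalone check of that scan, assuming port 2333 is otherwise free; the helper is re-declared here so the snippet runs without nanovllm installed, and SO_REUSEADDR semantics vary by platform:

import socket

def find_available_port(start_port: int = 2333, max_attempts: int = 100) -> int:
    # Same bind-test scan as nanovllm.config.find_available_port.
    for i in range(max_attempts):
        port = start_port + i
        try:
            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
                s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
                s.bind(("localhost", port))
                return port
        except OSError:
            continue
    raise RuntimeError(f"no free port in [{start_port}, {start_port + max_attempts})")

# Occupy the default starting port, then confirm the scan skips past it.
blocker = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
blocker.bind(("localhost", 2333))
print(find_available_port(2333))  # typically 2334 on Linux, since 2333 is held by `blocker`
blocker.close()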
acestep/third_parts/nano-vllm/nanovllm/engine/llm_engine.py CHANGED
@@ -21,6 +21,28 @@ class LLMEngine:
         self.ps = []
         self.events = []
         ctx = mp.get_context("spawn")
+
+        # Pre-validate port availability by attempting to bind to it
+        # This helps avoid race conditions when multiple LLMEngine instances start simultaneously
+        import socket
+        from nanovllm.config import find_available_port
+        max_port_retries = 10
+        for port_attempt in range(max_port_retries):
+            try:
+                # Test if port is actually available by binding to it
+                test_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+                test_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+                test_socket.bind(('localhost', config.dist_port))
+                test_socket.close()
+                # Port is available, break
+                break
+            except OSError:
+                # Port is in use, find next available
+                if port_attempt < max_port_retries - 1:
+                    config.dist_port = find_available_port(start_port=config.dist_port + 1, max_attempts=10)
+                else:
+                    raise RuntimeError(f"Failed to find available port after {max_port_retries} attempts")
+
         for i in range(1, config.tensor_parallel_size):
             event = ctx.Event()
             process = ctx.Process(target=ModelRunner, args=(config, i, event))
@@ -28,8 +50,7 @@ class LLMEngine:
             self.ps.append(process)
             self.events.append(event)
         self.model_runner = ModelRunner(config, 0, self.events)
-        self.tokenizer = AutoTokenizer.from_pretrained(config.model, use_fast=True)
-        config.eos = self.tokenizer.eos_token_id
+        self.tokenizer = None
         self.scheduler = Scheduler(config)
         atexit.register(self.exit)
 
acestep/third_parts/nano-vllm/nanovllm/engine/model_runner.py CHANGED
@@ -1,10 +1,11 @@
 import pickle
+import socket
 import torch
 import torch.distributed as dist
 from multiprocessing.synchronize import Event
 from multiprocessing.shared_memory import SharedMemory
 
-from nanovllm.config import Config
+from nanovllm.config import Config, find_available_port
 from nanovllm.engine.sequence import Sequence
 from nanovllm.models.qwen3 import Qwen3ForCausalLM
 from nanovllm.layers.sampler import Sampler
@@ -23,7 +24,32 @@ class ModelRunner:
         self.rank = rank
         self.event = event
 
-        dist.init_process_group("nccl", "tcp://localhost:2333", world_size=self.world_size, rank=rank)
+        # Try to initialize process group with retry logic for port conflicts
+        # Only rank 0 binds to the port, so only rank 0 needs retry logic
+        dist_port = self.config.dist_port
+        max_retries = 10
+        for attempt in range(max_retries):
+            try:
+                dist.init_process_group("nccl", f"tcp://localhost:{dist_port}", world_size=self.world_size, rank=rank)
+                break
+            except RuntimeError as e:
+                if ("EADDRINUSE" in str(e) or "address already in use" in str(e).lower()) and rank == 0:
+                    # Port is in use, try next port (only for rank 0)
+                    if attempt < max_retries - 1:
+                        # Find next available port
+                        dist_port = find_available_port(start_port=dist_port + 1, max_attempts=10)
+                        self.config.dist_port = dist_port
+                        # If we had a previous failed attempt, destroy any partial process group
+                        if dist.is_initialized():
+                            try:
+                                dist.destroy_process_group()
+                            except:
+                                pass
+                    else:
+                        raise RuntimeError(f"Failed to find available port after {max_retries} attempts. Last error: {e}")
+                else:
+                    # Other error or non-rank-0 process, re-raise
+                    raise
         torch.cuda.set_device(rank)
         default_dtype = torch.get_default_dtype()
         torch.set_default_dtype(hf_config.torch_dtype)
@@ -118,9 +144,15 @@ class ModelRunner:
             layer_id += 1
 
     def prepare_block_tables(self, seqs: list[Sequence]):
-        max_len = max(len(seq.block_table) for seq in seqs)
+        max_len = max(len(seq.block_table) for seq in seqs) if seqs else 0
+        if max_len == 0:
+            # Return empty 2D tensor with correct shape
+            return torch.zeros((len(seqs), 0), dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
         block_tables = [seq.block_table + [-1] * (max_len - len(seq.block_table)) for seq in seqs]
         block_tables = torch.tensor(block_tables, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
+        # Ensure it's 2D: if only one sequence, shape should be [1, max_len]
+        if block_tables.dim() == 1:
+            block_tables = block_tables.unsqueeze(0)
         return block_tables
 
     def prepare_prefill(self, seqs: list[Sequence]):
@@ -215,7 +247,29 @@ class ModelRunner:
         graph_vars["slot_mapping"][:bs] = context.slot_mapping
         graph_vars["context_lens"].zero_()
         graph_vars["context_lens"][:bs] = context.context_lens
-        graph_vars["block_tables"][:bs, :context.block_tables.size(1)] = context.block_tables
+        # Handle block_tables: ensure it's 2D and size matches
+        if context.block_tables is not None and context.block_tables.numel() > 0:
+            # Ensure block_tables is 2D
+            if context.block_tables.dim() == 1:
+                # Reshape 1D to 2D: [num_blocks] -> [1, num_blocks]
+                block_tables_2d = context.block_tables.unsqueeze(0)
+            else:
+                block_tables_2d = context.block_tables
+
+            # Get dimensions
+            context_bs = block_tables_2d.size(0)
+            context_num_blocks = block_tables_2d.size(1)
+            graph_num_blocks = graph_vars["block_tables"].size(1)
+
+            # Use minimum to avoid size mismatch
+            num_blocks_to_copy = min(context_num_blocks, graph_num_blocks)
+            actual_bs = min(bs, context_bs)
+
+            # Copy block_tables with size matching
+            graph_vars["block_tables"][:actual_bs, :num_blocks_to_copy] = block_tables_2d[:actual_bs, :num_blocks_to_copy]
+            # Fill remaining with -1 if needed
+            if num_blocks_to_copy < graph_num_blocks:
+                graph_vars["block_tables"][:actual_bs, num_blocks_to_copy:] = -1
         graph.replay()
         return self.model.compute_logits(graph_vars["outputs"][:bs])
 
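Note: the prepare_block_tables change guarantees a 2D, -1-padded block table even for a single sequence or an empty batch. A CPU-only sketch of that padding, with made-up block ids:

import torch

block_tables = [[0, 1, 2], [3], [4, 5]]  # per-sequence block ids (illustrative)
max_len = max(len(t) for t in block_tables) if block_tables else 0
padded = [t + [-1] * (max_len - len(t)) for t in block_tables]
tensor = torch.tensor(padded, dtype=torch.int32)
print(tensor)
# tensor([[ 0,  1,  2],
#         [ 3, -1, -1],
#         [ 4,  5, -1]], dtype=torch.int32)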
 
requirements.txt CHANGED
@@ -4,4 +4,5 @@ diffusers
 gradio
 soundfile
 loguru
-einops
+einops
+accelerator