ChuxiJ committed · Commit 76de6b9 · 1 Parent(s): e19bc36

fix duration

acestep/gradio_ui/interfaces/generation.py CHANGED
@@ -364,7 +364,7 @@ def create_generation_section(dit_handler, llm_handler, init_params=None, langua
                 minimum=-1,
                 maximum=600.0,
                 step=1,
-                info="Use -1 for random",
+                info="Use -1 for auto, or 10-600 seconds",
                 scale=1,
             )
             vocal_language = gr.Dropdown(
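
For orientation, a minimal sketch of the duration slider as it might read after this change. Only minimum, maximum, step, info, and scale come from the hunk above; the variable name, label, and default value are assumptions added for illustration.

import gradio as gr

# Hypothetical reconstruction of the duration control; only the keyword arguments
# shown in the diff above are from this commit, the name/label/value are illustrative.
audio_duration = gr.Slider(
    label="Audio duration (seconds)",  # assumed label
    value=-1,                          # assumed default: -1 = auto
    minimum=-1,
    maximum=600.0,
    step=1,
    info="Use -1 for auto, or 10-600 seconds",
    scale=1,
)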
acestep/inference.py CHANGED
@@ -447,6 +447,7 @@ def generate_music(
         negative_prompt=params.lm_negative_prompt,
         top_k=top_k_value,
         top_p=top_p_value,
+        target_duration=audio_duration,  # Pass duration to limit audio codes generation
         user_metadata=user_metadata_to_pass,
         use_cot_caption=params.use_cot_caption,
         use_cot_language=params.use_cot_language,
acestep/llm_inference.py CHANGED
@@ -474,8 +474,20 @@ class LLMHandler:
             codes_temperature=codes_temperature,
         )
 
+        # Calculate max_tokens based on target_duration if specified
+        # 5 audio codes = 1 second, plus ~500 tokens for CoT metadata and safety margin
+        if target_duration is not None and target_duration > 0:
+            # Ensure duration is within valid range (10-600 seconds)
+            effective_duration = max(10, min(600, target_duration))
+            max_tokens = int(effective_duration * 5) + 500
+            # Cap at model's max length
+            max_tokens = min(max_tokens, self.max_model_len - 64)
+        else:
+            # No duration constraint - use default (model will stop at EOS naturally)
+            max_tokens = self.max_model_len - 64
+
         sampling_params = SamplingParams(
-            max_tokens=self.max_model_len - 64,
+            max_tokens=max_tokens,
             temperature=effective_sampler_temp,
             cfg_scale=cfg_scale,
             top_k=top_k,
@@ -566,7 +578,17 @@ class LLMHandler:
 
         with self._load_model_context():
             inputs = {k: v.to(self.device) for k, v in inputs.items()}
-            max_new_tokens = getattr(self.llm.config, "max_new_tokens", 4096)
+
+            # Calculate max_new_tokens based on target_duration if specified
+            # 5 audio codes = 1 second, plus ~500 tokens for CoT metadata and safety margin
+            if target_duration is not None and target_duration > 0:
+                # Ensure duration is within valid range (10-600 seconds)
+                effective_duration = max(10, min(600, target_duration))
+                max_new_tokens = int(effective_duration * 5) + 500
+            else:
+                max_new_tokens = getattr(self.llm.config, "max_new_tokens", 4096)
+
+            # Cap at model's max length
             if hasattr(self, "max_model_len"):
                 max_new_tokens = min(max_new_tokens, self.max_model_len - 64)
 
@@ -1927,6 +1949,18 @@ class LLMHandler:
             return output_text, f"✅ Generated successfully (pt) | length={len(output_text)}"
 
         except Exception as e:
+            # Reset nano-vllm state on error to prevent stale context from causing
+            # subsequent CUDA illegal memory access errors
+            if self.llm_backend == "vllm":
+                try:
+                    from nanovllm.utils.context import reset_context
+                    reset_context()
+                except ImportError:
+                    pass
+            # Clear CUDA cache to release any corrupted memory
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
+                torch.cuda.synchronize()
             return "", f"❌ Error generating from formatted prompt: {e}"
 
     def _generate_with_constrained_decoding(
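
Both hunks above apply the same token budget: 5 audio codes per second of target audio plus roughly 500 tokens of headroom for CoT metadata, clamped to the 10-600 s duration range and capped at the model context. A standalone sketch of that arithmetic, with an assumed 4096-token context for the example calls:

def duration_to_token_budget(target_duration, max_model_len):
    """Mirrors the budget logic in the hunks above: 5 codes/sec plus a ~500-token margin."""
    if target_duration is not None and target_duration > 0:
        effective_duration = max(10, min(600, target_duration))  # clamp to 10-600 s
        budget = int(effective_duration * 5) + 500
        return min(budget, max_model_len - 64)  # never exceed the model context
    return max_model_len - 64  # no constraint: rely on the model emitting EOS

# Worked examples (assuming max_model_len = 4096):
#   120 s -> 120 * 5 + 500 = 1100 tokens
#   600 s -> 600 * 5 + 500 = 3500 tokens
#   None  -> 4096 - 64     = 4032 tokens
print(duration_to_token_budget(120, 4096))   # 1100
print(duration_to_token_budget(None, 4096))  # 4032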
acestep/third_parts/nano-vllm/nanovllm/engine/block_manager.py CHANGED
@@ -86,6 +86,13 @@ class BlockManager:
             block = self.blocks[block_id]
             block.ref_count -= 1
             if block.ref_count == 0:
+                # Fix: Clean up hash_to_block_id mapping to prevent stale references
+                # This prevents CUDA illegal memory access when prefix cache tries to
+                # reuse a block_id that has already been freed
+                if block.hash != -1:
+                    cached_id = self.hash_to_block_id.get(block.hash)
+                    if cached_id == block_id:
+                        del self.hash_to_block_id[block.hash]
                 self._deallocate_block(block_id)
         seq.num_cached_tokens = 0
         seq.block_table.clear()
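
A minimal illustration of the stale-mapping hazard this hunk closes, using a stripped-down dictionary in place of the real prefix cache; the Block class and free list below are simplified assumptions, not the nanovllm types.

# Simplified model of the prefix-cache bookkeeping, for illustration only.
class Block:
    def __init__(self, block_id):
        self.block_id = block_id
        self.ref_count = 1
        self.hash = -1  # -1 means no prefix hash recorded

hash_to_block_id = {}
free_block_ids = []

def release(block):
    block.ref_count -= 1
    if block.ref_count == 0:
        # Without this cleanup, hash_to_block_id keeps pointing at a freed
        # block_id, and a later prefix-cache hit can hand out a block that
        # has already been reassigned to another sequence.
        if block.hash != -1 and hash_to_block_id.get(block.hash) == block.block_id:
            del hash_to_block_id[block.hash]
        free_block_ids.append(block.block_id)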
acestep/third_parts/nano-vllm/nanovllm/engine/model_runner.py CHANGED
@@ -325,6 +325,17 @@ class ModelRunner:
             # Fall back to eager mode when block_tables is too large for CUDA graph
             return self.model.compute_logits(self.model(input_ids, positions))
 
+        # Fix: Also check if block_tables row count matches batch size
+        # Dimension mismatch can cause CUDA illegal memory access during graph replay
+        if context.block_tables.size(0) != bs:
+            # Fall back to eager mode when block_tables row count doesn't match batch size
+            return self.model.compute_logits(self.model(input_ids, positions))
+
+        # Fix: Verify slot_mapping and context_lens dimensions match batch size
+        if context.slot_mapping.size(0) != bs or context.context_lens.size(0) != bs:
+            # Fall back to eager mode when dimensions don't match
+            return self.model.compute_logits(self.model(input_ids, positions))
+
         graph = self.graphs[next(x for x in self.graph_bs if x >= bs)]
         graph_vars = self.graph_vars
         graph_vars["input_ids"][:bs] = input_ids
@@ -416,9 +427,10 @@
             ).tolist()
 
             # Update logits processor state after sampling
-            for i, seq in enumerate(cond_seqs):
-                if seq.logits_processor_update_state is not None:
-                    seq.logits_processor_update_state(token_ids_cfg[i])
+            # NOTE: Only update for the first sequence since all sequences share the same processor
+            # Updating multiple times would cause duplicate state updates (e.g., codes_count += N instead of += 1)
+            if cond_seqs and cond_seqs[0].logits_processor_update_state is not None:
+                cond_seqs[0].logits_processor_update_state(token_ids_cfg[0])
 
             # Return token_ids (will be applied to both conditional and unconditional sequences)
             return token_ids_cfg
@@ -483,9 +495,11 @@
             ).tolist()
 
             # Update logits processor state after sampling
-            for i, seq in enumerate(seqs):
-                if seq.logits_processor_update_state is not None:
-                    seq.logits_processor_update_state(token_ids[i])
+            # NOTE: Only update for the first sequence since all sequences may share the same processor
+            # (when using a single SamplingParams for batch generation)
+            # Updating multiple times would cause duplicate state updates (e.g., codes_count += N instead of += 1)
+            if seqs and seqs[0].logits_processor_update_state is not None:
+                seqs[0].logits_processor_update_state(token_ids[0])
 
             return token_ids
         else:
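
The three CUDA-graph guards added in the first hunk share one pattern: verify that every per-sequence tensor has exactly bs rows before replaying a captured graph, and fall back to eager execution otherwise. A standalone sketch of that check; the helper name is hypothetical, while the tensor names mirror the diff.

import torch

def graph_inputs_match_batch(block_tables: torch.Tensor,
                             slot_mapping: torch.Tensor,
                             context_lens: torch.Tensor,
                             bs: int) -> bool:
    """Return True only if every per-sequence tensor has exactly `bs` rows.

    A batch-size mismatch during CUDA graph replay can index past the buffers
    captured with the graph, which is the illegal-memory-access failure the
    guards above defend against.
    """
    return (block_tables.size(0) == bs
            and slot_mapping.size(0) == bs
            and context_lens.size(0) == bs)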