ChuxiJ committed on
Commit 7f5c13a · 1 Parent(s): ba7469b

full meta input

acestep/constrained_logits_processor.py CHANGED
@@ -1662,16 +1662,16 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
         elif self.state in [FSMState.BPM_VALUE, FSMState.DURATION_VALUE, FSMState.TIMESIG_VALUE]:
             # Accumulate numeric value using token ID sequence
             if generated_token_id == self.newline_token:
-                if self.state == FSMState.DURATION_VALUE and self.accumulated_value:
-                    try:
-                        generated_duration = int(self.accumulated_value)
-                        if self.target_codes is None and generated_duration > 0:
-                            self.target_codes = int(generated_duration * 5)
-                            if self.debug:
-                                logger.debug(f"Synced duration: {generated_duration}s -> Set target_codes limit to {self.target_codes}")
-                    except ValueError:
-                        if self.debug:
-                            logger.warning(f"Could not parse duration value: {self.accumulated_value}")
+                # if self.state == FSMState.DURATION_VALUE and self.accumulated_value:
+                #     try:
+                #         generated_duration = int(self.accumulated_value)
+                #         if self.target_codes is None and generated_duration > 0:
+                #             self.target_codes = int(generated_duration * 5)
+                #             if self.debug:
+                #                 logger.debug(f"Synced duration: {generated_duration}s -> Set target_codes limit to {self.target_codes}")
+                #     except ValueError:
+                #         if self.debug:
+                #             logger.warning(f"Could not parse duration value: {self.accumulated_value}")
                 # Newline ends the field
                 # Save old state before transition
                 old_state = self.state
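
The block disabled here derived a hard cap on generated audio codes from the model's own duration output, at 5 codes per second of audio. A minimal standalone sketch of that mapping, assuming the names and the 5-codes-per-second rate visible in the diff (the helper function itself is hypothetical):

from typing import Optional

CODES_PER_SECOND = 5  # rate implied by `self.target_codes = int(generated_duration * 5)`

def duration_to_target_codes(accumulated_value: str,
                             target_codes: Optional[int]) -> Optional[int]:
    # Hypothetical standalone version of the commented-out sync logic.
    try:
        generated_duration = int(accumulated_value)
    except ValueError:
        return target_codes  # unparseable duration: leave any existing cap alone
    if target_codes is None and generated_duration > 0:
        return generated_duration * CODES_PER_SECOND
    return target_codes

# A generated "duration: 30" field would have capped decoding at 150 codes.
assert duration_to_target_codes("30", None) == 150
assert duration_to_target_codes("oops", None) is None

With the block disabled, `target_codes` is no longer inferred from the generated duration; presumably the cap now comes from the explicit `target_duration` that `llm_inference.py` (below) passes through in `cfg`.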
acestep/llm_inference.py CHANGED
@@ -448,6 +448,14 @@ class LLMHandler:
 
         output_text = self.llm_tokenizer.decode(generated_ids, skip_special_tokens=False)
         return output_text
+
+    def has_all_metas(self, user_metadata: Optional[Dict[str, Optional[str]]]) -> bool:
+        """Check if all required metadata are present."""
+        if user_metadata is None:
+            return False
+        if 'bpm' in user_metadata and 'keyscale' in user_metadata and 'timesignature' in user_metadata and 'duration' in user_metadata and 'genres' in user_metadata:
+            return True
+        return False
 
     def generate_with_stop_condition(
         self,
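
The new `has_all_metas` helper only tests key presence ('bpm', 'keyscale', 'timesignature', 'duration', 'genres'), not whether the values are usable. An equivalent, more compact formulation (a sketch, not the committed code):

from typing import Dict, Optional

REQUIRED_META_KEYS = ('bpm', 'keyscale', 'timesignature', 'duration', 'genres')

def has_all_metas(user_metadata: Optional[Dict[str, Optional[str]]]) -> bool:
    # Equivalent to the chained `in` checks above: every required key present.
    return user_metadata is not None and all(key in user_metadata for key in REQUIRED_META_KEYS)

Note that under either version a key mapped to None still counts as present, since the check is membership-based.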
@@ -485,28 +493,32 @@
 
         # Determine stop condition
         stop_at_reasoning = (infer_type == "dit")
-        # For llm_dit mode: use normal generation (stops at EOS)
-        output_text, status = self.generate_from_formatted_prompt(
-            formatted_prompt=formatted_prompt,
-            cfg={
-                "temperature": temperature,
-                "cfg_scale": cfg_scale,
-                "negative_prompt": negative_prompt,
-                "top_k": top_k,
-                "top_p": top_p,
-                "repetition_penalty": repetition_penalty,
-                "target_duration": target_duration,
-                "user_metadata": user_metadata,
-            },
-            use_constrained_decoding=use_constrained_decoding,
-            constrained_decoding_debug=constrained_decoding_debug,
-            stop_at_reasoning=stop_at_reasoning,
-        )
-        if not output_text:
-            return {}, "", status
+        has_all_metas = self.has_all_metas(user_metadata)
+        audio_codes = ""
+
+        if not has_all_metas or not stop_at_reasoning:
+            # For llm_dit mode: use normal generation (stops at EOS)
+            output_text, status = self.generate_from_formatted_prompt(
+                formatted_prompt=formatted_prompt,
+                cfg={
+                    "temperature": temperature,
+                    "cfg_scale": cfg_scale,
+                    "negative_prompt": negative_prompt,
+                    "top_k": top_k,
+                    "top_p": top_p,
+                    "repetition_penalty": repetition_penalty,
+                    "target_duration": target_duration,
+                    "user_metadata": user_metadata,
+                },
+                use_constrained_decoding=use_constrained_decoding,
+                constrained_decoding_debug=constrained_decoding_debug,
+                stop_at_reasoning=stop_at_reasoning,
+            )
+            if not output_text:
+                return {}, "", status
 
-        # Parse output
-        metadata, audio_codes = self.parse_lm_output(output_text)
+            # Parse output
+            metadata, audio_codes = self.parse_lm_output(output_text)
 
         codes_count = len(audio_codes.split('<|audio_code_')) - 1 if audio_codes else 0
         status_msg = f"✅ Generated successfully\nOutput length: {len(output_text)} chars\nCodes count: {codes_count}"
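
The net effect of the gate: when all five metadata fields are supplied and infer_type is "dit" (so `stop_at_reasoning` is true), the LLM call is skipped entirely, `audio_codes` stays empty, and `codes_count` falls back to 0. A small illustration of the code-counting expression from the diff, using made-up token strings in the diff's `<|audio_code_...` format:

# Hypothetical decoded output containing three audio-code tokens.
audio_codes = "<|audio_code_12|><|audio_code_7|><|audio_code_99|>"
codes_count = len(audio_codes.split('<|audio_code_')) - 1 if audio_codes else 0
assert codes_count == 3

# Skip path: full metadata in dit mode leaves audio_codes as "".
audio_codes = ""
codes_count = len(audio_codes.split('<|audio_code_')) - 1 if audio_codes else 0
assert codes_count == 0

One thing to watch on the skip path: `output_text` and `metadata` are only assigned inside the gated branch, so the f-string in `status_msg` (and the eventual return) presumably relies on initializations outside this hunk.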