Spaces: Running on A100
Commit: full meta input
Browse files
- acestep/constrained_logits_processor.py +10 -10
- acestep/llm_inference.py +33 -21
acestep/constrained_logits_processor.py
CHANGED
|
@@ -1662,16 +1662,16 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
|
|
| 1662 |
elif self.state in [FSMState.BPM_VALUE, FSMState.DURATION_VALUE, FSMState.TIMESIG_VALUE]:
|
| 1663 |
# Accumulate numeric value using token ID sequence
|
| 1664 |
if generated_token_id == self.newline_token:
|
| 1665 |
-
if self.state == FSMState.DURATION_VALUE and self.accumulated_value:
|
| 1666 |
-
|
| 1667 |
-
|
| 1668 |
-
|
| 1669 |
-
|
| 1670 |
-
|
| 1671 |
-
|
| 1672 |
-
|
| 1673 |
-
|
| 1674 |
-
|
| 1675 |
# Newline ends the field
|
| 1676 |
# Save old state before transition
|
| 1677 |
old_state = self.state
|
|
|
|
| 1662 |
elif self.state in [FSMState.BPM_VALUE, FSMState.DURATION_VALUE, FSMState.TIMESIG_VALUE]:
|
| 1663 |
# Accumulate numeric value using token ID sequence
|
| 1664 |
if generated_token_id == self.newline_token:
|
| 1665 |
+
# if self.state == FSMState.DURATION_VALUE and self.accumulated_value:
|
| 1666 |
+
# try:
|
| 1667 |
+
# generated_duration = int(self.accumulated_value)
|
| 1668 |
+
# if self.target_codes is None and generated_duration > 0:
|
| 1669 |
+
# self.target_codes = int(generated_duration * 5)
|
| 1670 |
+
# if self.debug:
|
| 1671 |
+
# logger.debug(f"Synced duration: {generated_duration}s -> Set target_codes limit to {self.target_codes}")
|
| 1672 |
+
# except ValueError:
|
| 1673 |
+
# if self.debug:
|
| 1674 |
+
# logger.warning(f"Could not parse duration value: {self.accumulated_value}")
|
| 1675 |
# Newline ends the field
|
| 1676 |
# Save old state before transition
|
| 1677 |
old_state = self.state
|
acestep/llm_inference.py
CHANGED
|
@@ -448,6 +448,14 @@ class LLMHandler:
|
|
| 448 |
|
| 449 |
output_text = self.llm_tokenizer.decode(generated_ids, skip_special_tokens=False)
|
| 450 |
return output_text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 451 |
|
| 452 |
def generate_with_stop_condition(
|
| 453 |
self,
|
|
@@ -485,28 +493,32 @@ class LLMHandler:
|
|
| 485 |
|
| 486 |
# Determine stop condition
|
| 487 |
stop_at_reasoning = (infer_type == "dit")
|
| 488 |
-
|
| 489 |
-
|
| 490 |
-
|
| 491 |
-
|
| 492 |
-
|
| 493 |
-
|
| 494 |
-
|
| 495 |
-
|
| 496 |
-
|
| 497 |
-
|
| 498 |
-
|
| 499 |
-
|
| 500 |
-
|
| 501 |
-
|
| 502 |
-
|
| 503 |
-
|
| 504 |
-
|
| 505 |
-
|
| 506 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 507 |
|
| 508 |
-
|
| 509 |
-
|
| 510 |
|
| 511 |
codes_count = len(audio_codes.split('<|audio_code_')) - 1 if audio_codes else 0
|
| 512 |
status_msg = f"✅ Generated successfully\nOutput length: {len(output_text)} chars\nCodes count: {codes_count}"
|
|
|
|
| 448 |
|
| 449 |
output_text = self.llm_tokenizer.decode(generated_ids, skip_special_tokens=False)
|
| 450 |
return output_text
|
| 451 |
+
|
| 452 |
+
def has_all_metas(self, user_metadata: Optional[Dict[str, Optional[str]]]) -> bool:
|
| 453 |
+
"""Check if all required metadata are present."""
|
| 454 |
+
if user_metadata is None:
|
| 455 |
+
return False
|
| 456 |
+
if 'bpm' in user_metadata and 'keyscale' in user_metadata and 'timesignature' in user_metadata and 'duration' in user_metadata and 'genres' in user_metadata:
|
| 457 |
+
return True
|
| 458 |
+
return False
|
| 459 |
|
| 460 |
def generate_with_stop_condition(
|
| 461 |
self,
|
|
|
|
| 493 |
|
| 494 |
# Determine stop condition
|
| 495 |
stop_at_reasoning = (infer_type == "dit")
|
| 496 |
+
has_all_metas = self.has_all_metas(user_metadata)
|
| 497 |
+
audio_codes = ""
|
| 498 |
+
|
| 499 |
+
if not has_all_metas or not stop_at_reasoning:
|
| 500 |
+
# For llm_dit mode: use normal generation (stops at EOS)
|
| 501 |
+
output_text, status = self.generate_from_formatted_prompt(
|
| 502 |
+
formatted_prompt=formatted_prompt,
|
| 503 |
+
cfg={
|
| 504 |
+
"temperature": temperature,
|
| 505 |
+
"cfg_scale": cfg_scale,
|
| 506 |
+
"negative_prompt": negative_prompt,
|
| 507 |
+
"top_k": top_k,
|
| 508 |
+
"top_p": top_p,
|
| 509 |
+
"repetition_penalty": repetition_penalty,
|
| 510 |
+
"target_duration": target_duration,
|
| 511 |
+
"user_metadata": user_metadata,
|
| 512 |
+
},
|
| 513 |
+
use_constrained_decoding=use_constrained_decoding,
|
| 514 |
+
constrained_decoding_debug=constrained_decoding_debug,
|
| 515 |
+
stop_at_reasoning=stop_at_reasoning,
|
| 516 |
+
)
|
| 517 |
+
if not output_text:
|
| 518 |
+
return {}, "", status
|
| 519 |
|
| 520 |
+
# Parse output
|
| 521 |
+
metadata, audio_codes = self.parse_lm_output(output_text)
|
| 522 |
|
| 523 |
codes_count = len(audio_codes.split('<|audio_code_')) - 1 if audio_codes else 0
|
| 524 |
status_msg = f"✅ Generated successfully\nOutput length: {len(output_text)} chars\nCodes count: {codes_count}"
|