Commit 1bdda7b by xushengyuan · 1 Parent(s): 447806b

add 5hz llm test support & fix 5hz llm transformers inference

Files changed (2)
  1. acestep/handler.py +47 -19
  2. test.py +52 -6
acestep/handler.py CHANGED
@@ -20,7 +20,8 @@ from tqdm import tqdm
 from loguru import logger
 import warnings
 
-from transformers import AutoTokenizer, AutoModel
+from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM
+from transformers.generation.streamers import BaseStreamer
 from diffusers.models import AutoencoderOobleck
 
 
@@ -175,6 +176,8 @@ class AceStepHandler:
         try:
             if device == "auto":
                 device = "cuda" if torch.cuda.is_available() else "cpu"
+
+            status_msg = ""
 
             self.device = device
             self.offload_to_cpu = offload_to_cpu
@@ -203,7 +206,6 @@ class AceStepHandler:
                 self.model = AutoModel.from_pretrained(
                     acestep_v15_checkpoint_path,
                     trust_remote_code=True,
-                    dtype=self.dtype,
                     attn_implementation=attn_implementation
                 )
             except Exception as e:
@@ -214,7 +216,6 @@ class AceStepHandler:
                 self.model = AutoModel.from_pretrained(
                     acestep_v15_checkpoint_path,
                     trust_remote_code=True,
-                    dtype=self.dtype,
                     attn_implementation=attn_implementation
                 )
             else:
@@ -299,8 +300,11 @@ class AceStepHandler:
                 # vllm initialization failed, fallback to PyTorch
                 if not self.llm_initialized:
                     try:
-                        self.llm = AutoModel.from_pretrained(full_lm_model_path)
-                        self.llm = self.llm.to(device).to(self.dtype)
+                        self.llm = AutoModelForCausalLM.from_pretrained(full_lm_model_path, trust_remote_code=True)
+                        if not self.offload_to_cpu:
+                            self.llm = self.llm.to(device).to(self.dtype)
+                        else:
+                            self.llm = self.llm.to("cpu").to(self.dtype)
                         self.llm.eval()
                         self.llm_backend = "pt"
                         self.llm_initialized = True
@@ -311,9 +315,12 @@ class AceStepHandler:
             else:
                 # For CPU or other devices, use PyTorch backend
                 try:
-                    self.llm = AutoModel.from_pretrained(full_lm_model_path)
-                    self.llm_tokenizer = AutoTokenizer.from_pretrained(full_lm_model_path, use_fast=True)
-                    self.llm = self.llm.to(device).to(self.dtype)
+                    self.llm = AutoModelForCausalLM.from_pretrained(full_lm_model_path, trust_remote_code=True)
+                    self.llm_tokenizer = AutoTokenizer.from_pretrained(full_lm_model_path, use_fast=True, trust_remote_code=True)
+                    if not self.offload_to_cpu:
+                        self.llm = self.llm.to(device).to(self.dtype)
+                    else:
+                        self.llm = self.llm.to("cpu").to(self.dtype)
                     self.llm.eval()
                     self.llm_backend = "pt"
                     self.llm_initialized = True
@@ -328,7 +335,7 @@ class AceStepHandler:
             # Determine actual attention implementation used
             actual_attn = getattr(self.config, "_attn_implementation", "eager")
 
-            status_msg = f"✅ Model initialized successfully on {device}\n"
+            status_msg = f"✅ Model initialized successfully on {device}\n" + status_msg
             status_msg += f"Main model: {acestep_v15_checkpoint_path}\n"
             status_msg += f"VAE: {vae_checkpoint_path}\n"
             status_msg += f"Text encoder: {text_encoder_path}\n"
@@ -581,22 +588,43 @@ class AceStepHandler:
            padding=False,
            truncation=True,
        )
-        inputs = {k: v.to(self.device) for k, v in inputs.items()}
 
        # Generate with the model
-        with torch.no_grad():
+        with self._load_model_context("llm"):
+            inputs = {k: v.to(self.device) for k, v in inputs.items()}
+
            # Get max_new_tokens from model config or use a default
            max_new_tokens = getattr(self.llm.config, 'max_new_tokens', 4096)
            if hasattr(self, 'max_model_len'):
                max_new_tokens = min(max_new_tokens, self.max_model_len)
 
-            outputs = self.llm.generate(
-                **inputs,
-                max_new_tokens=max_new_tokens,
-                temperature=temperature,
-                do_sample=True if temperature > 0 else False,
-                pad_token_id=self.llm_tokenizer.pad_token_id or self.llm_tokenizer.eos_token_id,
-            )
+            # Define custom streamer for tqdm
+            class TqdmTokenStreamer(BaseStreamer):
+                def __init__(self, total):
+                    self.pbar = tqdm(total=total, desc="Generating 5Hz tokens", unit="token", maxinterval=1)
+
+                def put(self, value):
+                    # value is tensor of token ids
+                    if value.dim() > 1:
+                        num_tokens = value.numel()
+                    else:
+                        num_tokens = len(value)
+                    self.pbar.update(num_tokens)
+
+                def end(self):
+                    self.pbar.close()
+
+            streamer = TqdmTokenStreamer(total=max_new_tokens)
+
+            with torch.no_grad():
+                outputs = self.llm.generate(
+                    **inputs,
+                    max_new_tokens=max_new_tokens,
+                    temperature=temperature,
+                    do_sample=True if temperature > 0 else False,
+                    pad_token_id=self.llm_tokenizer.pad_token_id or self.llm_tokenizer.eos_token_id,
+                    streamer=streamer,
+                )
 
        # Decode the generated tokens
        # Only decode the newly generated tokens (skip the input prompt)
@@ -776,7 +804,7 @@ class AceStepHandler:
        # Expand to include quantizer dimension: [1, T_5Hz, num_quantizers]
        if indices.dim() == 2:
            indices = indices.unsqueeze(-1).expand(-1, -1, num_quantizers)
-
+        print(indices.shape)
        # Get quantized representation from indices: [1, T_5Hz, dim]
        quantized = quantizer.get_output_from_indices(indices)
        if quantized.dtype != self.dtype:
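
Note on the progress bar added above: transformers' generate() accepts a streamer object and calls its put() method first with the prompt ids and then once per newly sampled token, and end() when generation finishes; the TqdmTokenStreamer in this commit simply counts those calls. A minimal, self-contained sketch of the same pattern ("gpt2" is only a placeholder checkpoint, not the 5Hz LM path used by the handler):

import torch
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation.streamers import BaseStreamer

class TqdmTokenStreamer(BaseStreamer):
    """Count token ids as generate() emits them and drive a tqdm bar."""
    def __init__(self, total):
        self.pbar = tqdm(total=total, desc="Generating tokens", unit="token")

    def put(self, value):
        # Called first with the full prompt ids, then once per new token.
        self.pbar.update(value.numel() if value.dim() > 1 else len(value))

    def end(self):
        self.pbar.close()

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2").eval()
inputs = tokenizer("Hello", return_tensors="pt")
with torch.no_grad():
    model.generate(**inputs, max_new_tokens=32,
                   pad_token_id=tokenizer.eos_token_id,
                   streamer=TqdmTokenStreamer(total=32))
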
test.py CHANGED
@@ -35,13 +35,15 @@ def main():
     device = "xpu"
     print(f"Using device: {device}")
 
+    use_llm = False
+
     status, enabled = handler.initialize_service(
         project_root=project_root,
         config_path=model_name,
         device=device,
-        init_llm=True,
+        init_llm=use_llm,
         use_flash_attention=False, # Default in UI
-        compile_model=False,
+        compile_model=True,
         offload_to_cpu=True,
         offload_dit_to_cpu=False, # Keep DiT on GPU
     )
@@ -95,6 +97,49 @@ def main():
 
     print("Starting generation...")
 
+    # Generate hints using 5Hz LLM
+    if use_llm:
+        print("Generating hints using 5Hz LLM...")
+        lm_temperature = 0.6
+        metadata, audio_codes, lm_status = handler.generate_with_5hz_lm(captions, lyrics, lm_temperature)
+        print(f"5Hz LLM Status: {lm_status}")
+        print(f"Generated Metadata: {metadata}")
+        print(f"Generated Audio Codes (first 50 chars): {audio_codes[:50]}...")
+    else:
+        print("Skipping 5Hz LLM generation...")
+        metadata = {}
+        audio_codes = None
+        lm_status = "Skipped"
+
+    # Use generated metadata if available
+    bpm = metadata.get('bpm', 90)
+    if bpm == "N/A" or bpm == "":
+        bpm = 90
+    else:
+        try:
+            bpm = int(float(bpm))
+        except:
+            bpm = 90
+
+    key_scale = metadata.get('keyscale', metadata.get('key_scale', "A major"))
+    if key_scale == "N/A":
+        key_scale = "A major"
+
+    time_signature = metadata.get('timesignature', metadata.get('time_signature', "4"))
+    if time_signature == "N/A":
+        time_signature = "4"
+
+    audio_duration = metadata.get('duration', 120)
+    if audio_duration == "N/A":
+        audio_duration = 120
+    else:
+        try:
+            audio_duration = float(audio_duration)
+        except:
+            audio_duration = 120
+
+    print(f"Using parameters: BPM={bpm}, Key={key_scale}, Time Sig={time_signature}, Duration={audio_duration}")
+
     # Reset peak memory stats
     if hasattr(torch, 'xpu') and torch.xpu.is_available():
         torch.xpu.reset_peak_memory_stats()
@@ -105,21 +150,22 @@ def main():
     results = handler.generate_music(
         captions=captions,
         lyrics=lyrics,
-        bpm=90,
-        key_scale="A major",
-        time_signature="4",
+        bpm=bpm,
+        key_scale=key_scale,
+        time_signature=time_signature,
         vocal_language="zh",
         inference_steps=8,
         guidance_scale=7.0,
         use_random_seed=False,
         seed=seeds,
-        audio_duration=120,
+        audio_duration=audio_duration,
         batch_size=1,
         task_type="text2music",
         cfg_interval_start=0.0,
         cfg_interval_end=0.95,
         audio_format="wav",
         use_tiled_decode=True,
+        audio_code_string=audio_codes,
     )
 
     # Unpack results
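
The new test path coerces the LM-returned metadata inline; the same handling can be read as one small normalization step. A sketch of that step as a standalone helper (the name normalize_lm_metadata is hypothetical; keys and fallback defaults are taken from the diff above):

def normalize_lm_metadata(metadata: dict) -> dict:
    # LM output fields may be missing, "N/A", or strings, so coerce with safe defaults.
    def to_number(value, cast, default):
        try:
            return cast(value)
        except (TypeError, ValueError):
            return default

    bpm = metadata.get("bpm", 90)
    bpm = 90 if bpm in ("N/A", "") else to_number(bpm, lambda v: int(float(v)), 90)

    key_scale = metadata.get("keyscale", metadata.get("key_scale", "A major"))
    key_scale = "A major" if key_scale == "N/A" else key_scale

    time_signature = metadata.get("timesignature", metadata.get("time_signature", "4"))
    time_signature = "4" if time_signature == "N/A" else time_signature

    duration = metadata.get("duration", 120)
    duration = 120 if duration == "N/A" else to_number(duration, float, 120)

    return {"bpm": bpm, "key_scale": key_scale,
            "time_signature": time_signature, "audio_duration": duration}

# Example:
# normalize_lm_metadata({"bpm": "128.0", "keyscale": "N/A", "duration": "95.5"})
# -> {"bpm": 128, "key_scale": "A major", "time_signature": "4", "audio_duration": 95.5}
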