Commit 1d23edb by xushengyuan
Parent: 497ca57

cpu offloading

Files changed (3):
  1. acestep/gradio_ui.py +19 -2
  2. acestep/handler.py +335 -162
  3. test.py +3 -1
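
The change set adds an offload-on-demand mode: when "Offload to CPU" is enabled, the text encoder, VAE, and 5Hz LM are parked on the CPU and moved to the accelerator only for the duration of the call that needs them (the DiT can optionally stay resident, and transfer time is accumulated into a timing counter). The sketch below is a minimal, self-contained illustration of that pattern using a generic nn.Module; the OffloadingRunner name and its API are placeholders, not the repository's AceStepHandler.

# Minimal sketch of the offload-on-demand pattern introduced in this commit
# (illustrative only; not the repository's AceStepHandler).
import time
from contextlib import contextmanager

import torch
import torch.nn as nn


class OffloadingRunner:
    def __init__(self, model: nn.Module, device: str = "cuda", offload_to_cpu: bool = True):
        # Park the model on CPU up front when offloading is enabled.
        self.model = model.to("cpu" if offload_to_cpu else device)
        self.device = device
        self.offload_to_cpu = offload_to_cpu
        self.current_offload_cost = 0.0  # accumulated transfer time, as in handler.py

    @contextmanager
    def _load_model_context(self):
        """Move the model to the accelerator for the duration of the block, then park it on CPU."""
        if not self.offload_to_cpu:
            yield
            return
        start = time.time()
        self.model.to(self.device)
        self.current_offload_cost += time.time() - start
        try:
            yield
        finally:
            start = time.time()
            self.model.to("cpu")
            if torch.cuda.is_available():
                torch.cuda.empty_cache()  # hand the freed VRAM back to the allocator
            self.current_offload_cost += time.time() - start

    @torch.no_grad()
    def run(self, x: torch.Tensor) -> torch.Tensor:
        with self._load_model_context():
            return self.model(x.to(self.device)).cpu()


if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    runner = OffloadingRunner(nn.Linear(16, 16), device=device, offload_to_cpu=True)
    out = runner.run(torch.randn(2, 16))
    print(out.shape, f"transfer overhead: {runner.current_offload_cost:.4f}s")
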
acestep/gradio_ui.py CHANGED
@@ -216,6 +216,16 @@ def create_generation_section(handler) -> dict:
         interactive=flash_attn_available,
         info="Enable flash attention for faster inference (requires flash_attn package)" if flash_attn_available else "Flash attention not available (flash_attn package not installed)"
     )
+    offload_to_cpu_checkbox = gr.Checkbox(
+        label="Offload to CPU",
+        value=False,
+        info="Offload models to CPU when not in use to save GPU memory"
+    )
+    offload_dit_to_cpu_checkbox = gr.Checkbox(
+        label="Offload DiT to CPU",
+        value=False,
+        info="Offload DiT model to CPU when not in use (only effective if Offload to CPU is checked)"
+    )
 
     init_btn = gr.Button("Initialize Service", variant="primary", size="lg")
     init_status = gr.Textbox(label="Status", interactive=False, lines=3)
@@ -487,6 +497,7 @@ def create_generation_section(handler) -> dict:
         "lm_model_path": lm_model_path,
         "init_llm_checkbox": init_llm_checkbox,
         "use_flash_attention_checkbox": use_flash_attention_checkbox,
+        "offload_to_cpu_checkbox": offload_to_cpu_checkbox,
         "task_type": task_type,
         "instruction_display_gen": instruction_display_gen,
         "track_name": track_name,
@@ -655,9 +666,13 @@ def setup_event_handlers(demo, handler, dataset_section, generation_section, res
     )
 
     # Service initialization
-    def init_service_wrapper(checkpoint, config_path, device, init_llm, lm_model_path, use_flash_attention):
+    def init_service_wrapper(checkpoint, config_path, device, init_llm, lm_model_path, use_flash_attention, offload_to_cpu, offload_dit_to_cpu):
         """Wrapper for service initialization, returns status and button state"""
-        status, enable = handler.initialize_service(checkpoint, config_path, device, init_llm, lm_model_path, use_flash_attention)
+        status, enable = handler.initialize_service(
+            checkpoint, config_path, device, init_llm, lm_model_path,
+            use_flash_attention, compile_model=False,
+            offload_to_cpu=offload_to_cpu, offload_dit_to_cpu=offload_dit_to_cpu
+        )
         return status, gr.update(interactive=enable)
 
     generation_section["init_btn"].click(
@@ -669,6 +684,8 @@ def setup_event_handlers(demo, handler, dataset_section, generation_section, res
             generation_section["init_llm_checkbox"],
             generation_section["lm_model_path"],
             generation_section["use_flash_attention_checkbox"],
+            generation_section["offload_to_cpu_checkbox"],
+            generation_section["offload_dit_to_cpu_checkbox"],
         ],
         outputs=[generation_section["init_status"], generation_section["generate_btn"]]
     )
acestep/handler.py CHANGED
@@ -8,6 +8,7 @@ import tempfile
 import traceback
 import re
 import random
+from contextlib import contextmanager
 from typing import Optional, Dict, Any, Tuple, List, Union
 
 import torch
@@ -81,6 +82,9 @@ class AceStepHandler:
             5: [8, 9, 11],
             6: [8]
         }
+        self.offload_to_cpu = False
+        self.offload_dit_to_cpu = False
+        self.current_offload_cost = 0.0
 
     def get_available_checkpoints(self) -> str:
         """Return project root directory path"""
@@ -146,6 +150,8 @@
         lm_model_path: str = "acestep-5Hz-lm-0.6B",
         use_flash_attention: bool = False,
         compile_model: bool = False,
+        offload_to_cpu: bool = False,
+        offload_dit_to_cpu: bool = False,
     ) -> Tuple[str, bool]:
         """
         Initialize model service
@@ -158,6 +164,8 @@
             lm_model_path: 5Hz LM model path
             use_flash_attention: Whether to use flash attention (requires flash_attn package)
             compile_model: Whether to use torch.compile to optimize the model
+            offload_to_cpu: Whether to offload models to CPU when not in use
+            offload_dit_to_cpu: Whether to offload DiT model to CPU when not in use (only effective if offload_to_cpu is True)
 
         Returns:
             (status_message, enable_generate_button)
@@ -167,6 +175,8 @@
             device = "cuda" if torch.cuda.is_available() else "cpu"
 
         self.device = device
+        self.offload_to_cpu = offload_to_cpu
+        self.offload_dit_to_cpu = offload_dit_to_cpu
         # Set dtype based on device: bfloat16 for cuda, float32 for cpu
         self.dtype = torch.bfloat16 if device in ["cuda","xpu"] else torch.float32
 
@@ -211,7 +221,15 @@
         self.model.config._attn_implementation = attn_implementation
         self.config = self.model.config
         # Move model to device and set dtype
-        self.model = self.model.to(device).to(self.dtype)
+        if not self.offload_to_cpu:
+            self.model = self.model.to(device).to(self.dtype)
+        else:
+            # If offload_to_cpu is True, check if we should keep DiT on GPU
+            if not self.offload_dit_to_cpu:
+                logger.info(f"Keeping main model on {device} (persistent)")
+                self.model = self.model.to(device).to(self.dtype)
+            else:
+                self.model = self.model.to("cpu").to(self.dtype)
         self.model.eval()
 
         if compile_model:
@@ -221,7 +239,11 @@
            silence_latent_path = os.path.join(acestep_v15_checkpoint_path, "silence_latent.pt")
            if os.path.exists(silence_latent_path):
                self.silence_latent = torch.load(silence_latent_path).transpose(1, 2)
-               self.silence_latent = self.silence_latent.to(device).to(self.dtype)
+               # If DiT is on GPU, silence_latent should also be on GPU
+               if not self.offload_to_cpu or not self.offload_dit_to_cpu:
+                   self.silence_latent = self.silence_latent.to(device).to(self.dtype)
+               else:
+                   self.silence_latent = self.silence_latent.to("cpu").to(self.dtype)
            else:
                raise FileNotFoundError(f"Silence latent not found at {silence_latent_path}")
        else:
@@ -233,7 +255,10 @@
            self.vae = AutoencoderOobleck.from_pretrained(vae_checkpoint_path)
            # Use bfloat16 for VAE on GPU, otherwise use self.dtype (float32 on CPU)
            vae_dtype = torch.bfloat16 if device in ["cuda", "xpu"] else self.dtype
-           self.vae = self.vae.to(device).to(vae_dtype)
+           if not self.offload_to_cpu:
+               self.vae = self.vae.to(device).to(vae_dtype)
+           else:
+               self.vae = self.vae.to("cpu").to(vae_dtype)
            self.vae.eval()
        else:
            raise FileNotFoundError(f"VAE checkpoint not found at {vae_checkpoint_path}")
@@ -243,7 +268,10 @@
        if os.path.exists(text_encoder_path):
            self.text_tokenizer = AutoTokenizer.from_pretrained(text_encoder_path)
            self.text_encoder = AutoModel.from_pretrained(text_encoder_path)
-           self.text_encoder = self.text_encoder.to(device).to(self.dtype)
+           if not self.offload_to_cpu:
+               self.text_encoder = self.text_encoder.to(device).to(self.dtype)
+           else:
+               self.text_encoder = self.text_encoder.to("cpu").to(self.dtype)
            self.text_encoder.eval()
        else:
            raise FileNotFoundError(f"Text encoder not found at {text_encoder_path}")
@@ -252,12 +280,11 @@
        if init_llm:
            full_lm_model_path = os.path.join(checkpoint_dir, lm_model_path)
            if os.path.exists(full_lm_model_path):
-               if device == "cuda":
-                   status_msg = self._initialize_5hz_lm_cuda(full_lm_model_path)
-                   if not self.llm_initialized:
-                       return status_msg, False
-               self.llm = AutoModel.from_pretrained(full_lm_model_path)
-               self.llm_tokenizer = AutoTokenizer.from_pretrained(full_lm_model_path)
+               status_msg = self._initialize_5hz_lm(full_lm_model_path)
+               if not self.llm_initialized:
+                   print(f"Error initializing 5Hz LM: {status_msg}")
+                   return status_msg, False
+               print(status_msg)
            else:
                # 5Hz LM path not found
                return f"❌ 5Hz LM model not found at {full_lm_model_path}", False
@@ -275,7 +302,9 @@
            status_msg += f"5Hz LM model: Not loaded (checkbox not selected)\n"
        status_msg += f"Dtype: {self.dtype}\n"
        status_msg += f"Attention: {actual_attn}\n"
-       status_msg += f"Compiled: {compile_model}"
+       status_msg += f"Compiled: {compile_model}\n"
+       status_msg += f"Offload to CPU: {self.offload_to_cpu}\n"
+       status_msg += f"Offload DiT to CPU: {self.offload_dit_to_cpu}"
 
        return status_msg, True
 
@@ -283,6 +312,86 @@
            error_msg = f"❌ Error initializing model: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
            return error_msg, False
 
+    @contextmanager
+    def _load_model_context(self, model_name: str):
+        """
+        Context manager to load a model to GPU and offload it back to CPU after use.
+
+        Args:
+            model_name: Name of the model to load ("text_encoder", "vae", "model", "llm")
+        """
+        if not self.offload_to_cpu:
+            yield
+            return
+
+        # If model is DiT ("model") and offload_dit_to_cpu is False, do not offload
+        if model_name == "model" and not self.offload_dit_to_cpu:
+            # Ensure it's on device if not already (should be handled by init, but safe to check)
+            model = getattr(self, model_name, None)
+            if model is not None:
+                # Check if model is on CPU, if so move to device (one-time move if it was somehow on CPU)
+                # We check the first parameter's device
+                try:
+                    param = next(model.parameters())
+                    if param.device.type == "cpu":
+                        logger.info(f"Moving {model_name} to {self.device} (persistent)")
+                        model.to(self.device).to(self.dtype)
+                        if hasattr(self, "silence_latent"):
+                            self.silence_latent = self.silence_latent.to(self.device).to(self.dtype)
+                except StopIteration:
+                    pass
+            yield
+            return
+
+        # If model is LLM and using nanovllm, do not offload (it stays on GPU)
+        if model_name == "llm" and getattr(self, "llm_type", None) == "nanovllm":
+            yield
+            return
+
+        model = getattr(self, model_name, None)
+        if model is None:
+            yield
+            return
+
+        # Load to GPU
+        logger.info(f"Loading {model_name} to {self.device}")
+        start_time = time.time()
+        if model_name == "vae":
+            vae_dtype = torch.bfloat16 if self.device in ["cuda", "xpu"] else self.dtype
+            model.to(self.device).to(vae_dtype)
+        elif model_name == "llm" and hasattr(model, "to"):
+            # Special handling for nanovllm LLM which might have custom to() method or structure
+            # Assuming it has a .to() method based on our previous edits to nanovllm
+            model.to(self.device)
+        else:
+            model.to(self.device).to(self.dtype)
+
+        if model_name == "model" and hasattr(self, "silence_latent"):
+            self.silence_latent = self.silence_latent.to(self.device).to(self.dtype)
+
+        load_time = time.time() - start_time
+        self.current_offload_cost += load_time
+        logger.info(f"Loaded {model_name} to {self.device} in {load_time:.4f}s")
+
+        try:
+            yield
+        finally:
+            # Offload to CPU
+            logger.info(f"Offloading {model_name} to CPU")
+            start_time = time.time()
+            if model_name == "llm" and hasattr(model, "to"):
+                model.to("cpu")
+            else:
+                model.to("cpu")
+
+            if model_name == "model" and hasattr(self, "silence_latent"):
+                self.silence_latent = self.silence_latent.to("cpu")
+
+            torch.cuda.empty_cache()
+            offload_time = time.time() - start_time
+            self.current_offload_cost += offload_time
+            logger.info(f"Offloaded {model_name} to CPU in {offload_time:.4f}s")
+
     def import_dataset(self, dataset_type: str) -> str:
         """Import dataset (temporarily disabled)"""
         self.dataset_imported = False
@@ -314,36 +423,66 @@
        except Exception as e:
            return 0.9
 
-    def _initialize_5hz_lm_cuda(self, model_path: str) -> str:
+    def _initialize_5hz_lm(self, model_path: str) -> str:
        """Initialize 5Hz LM model"""
        try:
-           from nanovllm import LLM, SamplingParams
-
-           if not torch.cuda.is_available():
-               return "❌ CUDA is not available. Please check your GPU setup."
+           # Try to use nanovllm if on CUDA
+           use_nanovllm = False
+           if self.device == "cuda":
+               try:
+                   from nanovllm import LLM, SamplingParams
+                   use_nanovllm = True
+               except ImportError:
+                   pass
 
-           current_device = torch.cuda.current_device()
-           device_name = torch.cuda.get_device_name(current_device)
+           if use_nanovllm:
+               try:
+                   current_device = torch.cuda.current_device()
+                   device_name = torch.cuda.get_device_name(current_device)
+
+                   torch.cuda.empty_cache()
+                   gpu_memory_utilization = self.get_gpu_memory_utilization(
+                       minimal_gpu=8,
+                       min_ratio=0.2,
+                       max_ratio=0.9
+                   )
+
+                   self.llm = LLM(
+                       model=model_path,
+                       enforce_eager=False,
+                       tensor_parallel_size=1,
+                       max_model_len=4096,
+                       gpu_memory_utilization=gpu_memory_utilization,
+                   )
+                   self.llm_tokenizer = self.llm.tokenizer
+                   self.llm_initialized = True
+                   self.llm_type = "nanovllm"
+                   return f"✅ 5Hz LM initialized successfully (nanovllm)\nModel: {model_path}\nDevice: {device_name}\nGPU Memory Utilization: {gpu_memory_utilization:.2f}"
+               except Exception as e:
+                   logger.warning(f"nanovllm initialization failed: {e}, falling back to transformers")
+
+           # Fallback to transformers
+           from transformers import AutoModelForCausalLM
 
-           torch.cuda.empty_cache()
-           gpu_memory_utilization = self.get_gpu_memory_utilization(
-               minimal_gpu=8,
-               min_ratio=0.2,
-               max_ratio=0.9
+           self.llm = AutoModelForCausalLM.from_pretrained(
+               model_path,
+               torch_dtype=self.dtype,
+               trust_remote_code=True
            )
 
-           self.llm = LLM(
-               model=model_path,
-               enforce_eager=False,
-               tensor_parallel_size=1,
-               max_model_len=4096,
-               gpu_memory_utilization=gpu_memory_utilization,
-           )
-           self.llm_tokenizer = self.llm.tokenizer
+           if not self.offload_to_cpu:
+               self.llm.to(self.device)
+           else:
+               self.llm.to("cpu")
+
+           self.llm_tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
            self.llm_initialized = True
-           return f"✅ 5Hz LM initialized successfully\nModel: {model_path}\nDevice: {device_name}\nGPU Memory Utilization: {gpu_memory_utilization:.2f}"
+           self.llm_type = "transformers"
+           return f"✅ 5Hz LM initialized successfully (transformers)\nModel: {model_path}\nDevice: {self.device}"
+
        except Exception as e:
            self.llm_initialized = False
+           self.llm_type = None
            error_msg = f"❌ Error initializing 5Hz LM: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
            return error_msg
 
@@ -353,35 +492,54 @@
            return {}, "", "❌ 5Hz LM not initialized. Please initialize it first."
 
        try:
-           from nanovllm import SamplingParams
-
-           prompt = f"# Caption\n{caption}\n\n# Lyric\n{lyrics}\n"
-
-           formatted_prompt = self.lm_tokenizer.apply_chat_template(
-               [
-                   {"role": "system", "content": "# Instruction\nGenerate audio semantic tokens based on the given conditions:\n\n"},
-                   {"role": "user", "content": prompt}
-               ],
-               tokenize=False,
-               add_generation_prompt=True,
-           )
-
-           sampling_params = SamplingParams(max_tokens=3072, temperature=temperature)
-           outputs = self.llm.generate([formatted_prompt], sampling_params)
-
-           if isinstance(outputs, list) and len(outputs) > 0:
-               if hasattr(outputs[0], 'outputs') and len(outputs[0].outputs) > 0:
-                   output_text = outputs[0].outputs[0].text
-               elif hasattr(outputs[0], 'text'):
-                   output_text = outputs[0].text
+           with self._load_model_context("llm"):
+               prompt = f"# Caption\n{caption}\n\n# Lyric\n{lyrics}\n"
+
+               formatted_prompt = self.lm_tokenizer.apply_chat_template(
+                   [
+                       {"role": "system", "content": "# Instruction\nGenerate audio semantic tokens based on the given conditions:\n\n"},
+                       {"role": "user", "content": prompt}
+                   ],
+                   tokenize=False,
+                   add_generation_prompt=True,
+               )
+
+               if getattr(self, "llm_type", "nanovllm") == "nanovllm":
+                   from nanovllm import SamplingParams
+                   sampling_params = SamplingParams(max_tokens=3072, temperature=temperature)
+                   outputs = self.llm.generate([formatted_prompt], sampling_params)
+
+                   if isinstance(outputs, list) and len(outputs) > 0:
+                       if hasattr(outputs[0], 'outputs') and len(outputs[0].outputs) > 0:
+                           output_text = outputs[0].outputs[0].text
+                       elif hasattr(outputs[0], 'text'):
+                           output_text = outputs[0].text
+                       else:
+                           output_text = str(outputs[0])
+                   else:
+                       output_text = str(outputs)
                else:
-                   output_text = str(outputs[0])
-           else:
-               output_text = str(outputs)
-
-           metadata, audio_codes = self.parse_lm_output(output_text)
-           codes_count = len(audio_codes.split('<|audio_code_')) - 1 if audio_codes else 0
-           return metadata, audio_codes, f"✅ Generated successfully\nOutput length: {len(output_text)} chars\nCodes count: {codes_count}"
+                   # Transformers generation
+                   inputs = self.llm_tokenizer(formatted_prompt, return_tensors="pt").to(self.llm.device)
+
+                   # Generate
+                   with torch.no_grad():
+                       outputs = self.llm.generate(
+                           **inputs,
+                           max_new_tokens=3072,
+                           temperature=temperature,
+                           do_sample=True,
+                           pad_token_id=self.llm_tokenizer.pad_token_id,
+                           eos_token_id=self.llm_tokenizer.eos_token_id
+                       )
+
+                   # Decode
+                   generated_ids = outputs[0][inputs.input_ids.shape[1]:]
+                   output_text = self.llm_tokenizer.decode(generated_ids, skip_special_tokens=False)
+
+               metadata, audio_codes = self.parse_lm_output(output_text)
+               codes_count = len(audio_codes.split('<|audio_code_')) - 1 if audio_codes else 0
+               return metadata, audio_codes, f"✅ Generated successfully\nOutput length: {len(output_text)} chars\nCodes count: {codes_count}"
 
        except Exception as e:
            error_msg = f"❌ Error generating with 5Hz LM: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
@@ -495,24 +653,25 @@
        if len(code_ids) == 0:
            return None
 
-       quantizer = self.model.tokenizer.quantizer
-       detokenizer = self.model.detokenizer
-
-       num_quantizers = getattr(quantizer, "num_quantizers", 1)
-       indices = torch.tensor(code_ids, device=self.device, dtype=torch.long).unsqueeze(0) # [1, T_5Hz]
-
-       # Expand to include quantizer dimension: [1, T_5Hz, num_quantizers]
-       if indices.dim() == 2:
-           indices = indices.unsqueeze(-1).expand(-1, -1, num_quantizers)
-
-       # Get quantized representation from indices: [1, T_5Hz, dim]
-       quantized = quantizer.get_output_from_indices(indices)
-       if quantized.dtype != self.dtype:
-           quantized = quantized.to(self.dtype)
-
-       # Detokenize to 25Hz: [1, T_5Hz, dim] -> [1, T_25Hz, dim]
-       lm_hints_25hz = detokenizer(quantized)
-       return lm_hints_25hz
+       with self._load_model_context("model"):
+           quantizer = self.model.tokenizer.quantizer
+           detokenizer = self.model.detokenizer
+
+           num_quantizers = getattr(quantizer, "num_quantizers", 1)
+           indices = torch.tensor(code_ids, device=self.device, dtype=torch.long).unsqueeze(0) # [1, T_5Hz]
+
+           # Expand to include quantizer dimension: [1, T_5Hz, num_quantizers]
+           if indices.dim() == 2:
+               indices = indices.unsqueeze(-1).expand(-1, -1, num_quantizers)
+
+           # Get quantized representation from indices: [1, T_5Hz, dim]
+           quantized = quantizer.get_output_from_indices(indices)
+           if quantized.dtype != self.dtype:
+               quantized = quantized.to(self.dtype)
+
+           # Detokenize to 25Hz: [1, T_5Hz, dim] -> [1, T_25Hz, dim]
+           lm_hints_25hz = detokenizer(quantized)
+           return lm_hints_25hz
 
    def _create_default_meta(self) -> str:
        """Create default metadata string."""
@@ -577,30 +736,31 @@
        if self.text_tokenizer is None or self.text_encoder is None:
            raise ValueError("Text encoder not initialized")
 
-       # Tokenize
-       text_inputs = self.text_tokenizer(
-           text_prompt,
-           padding="longest",
-           truncation=True,
-           max_length=256,
-           return_tensors="pt",
-       )
-       text_input_ids = text_inputs.input_ids.to(self.device)
-       text_attention_mask = text_inputs.attention_mask.to(self.device).bool()
-
-       # Encode
-       with torch.no_grad():
-           text_outputs = self.text_encoder(text_input_ids)
-           if hasattr(text_outputs, 'last_hidden_state'):
-               text_hidden_states = text_outputs.last_hidden_state
-           elif isinstance(text_outputs, tuple):
-               text_hidden_states = text_outputs[0]
-           else:
-               text_hidden_states = text_outputs
-
-       text_hidden_states = text_hidden_states.to(self.dtype)
-
-       return text_hidden_states, text_attention_mask
+       with self._load_model_context("text_encoder"):
+           # Tokenize
+           text_inputs = self.text_tokenizer(
+               text_prompt,
+               padding="longest",
+               truncation=True,
+               max_length=256,
+               return_tensors="pt",
+           )
+           text_input_ids = text_inputs.input_ids.to(self.device)
+           text_attention_mask = text_inputs.attention_mask.to(self.device).bool()
+
+           # Encode
+           with torch.no_grad():
+               text_outputs = self.text_encoder(text_input_ids)
+               if hasattr(text_outputs, 'last_hidden_state'):
+                   text_hidden_states = text_outputs.last_hidden_state
+               elif isinstance(text_outputs, tuple):
+                   text_hidden_states = text_outputs[0]
+               else:
+                   text_hidden_states = text_outputs
+
+           text_hidden_states = text_hidden_states.to(self.dtype)
+
+           return text_hidden_states, text_attention_mask
 
    def extract_caption_from_sft_format(self, caption: str) -> str:
        try:
@@ -1103,7 +1263,7 @@ class AceStepHandler:
        if isinstance(refer_audio_list, list):
            for idx, refer_audio in enumerate(refer_audio_list):
                refer_audio_list[idx] = refer_audio_list[idx].to(self.device).to(torch.bfloat16)
-       elif isinstance(refer_audio_list, torch.tensor):
+       elif isinstance(refer_audio_list, torch.Tensor):
            refer_audios[ii] = refer_audios[ii].to(self.device)
 
        if vocal_languages is None:
@@ -1131,35 +1291,37 @@
        target_wavs_list = [target_wavs[i].clone() for i in range(batch_size)]
        if target_wavs.device != self.device:
            target_wavs = target_wavs.to(self.device)
-       for i in range(batch_size):
-           code_hint = audio_code_hints[i]
-           # Prefer decoding from provided audio codes
-           if code_hint:
-               print(f"[generate_music] Decoding audio codes for item {i}...")
-               decoded_latents = self._decode_audio_codes_to_latents(code_hint)
-               if decoded_latents is not None:
-                   decoded_latents = decoded_latents.squeeze(0)
-                   target_latents_list.append(decoded_latents)
-                   latent_lengths.append(decoded_latents.shape[0])
-                   # Create a silent wav matching the latent length for downstream scaling
-                   frames_from_codes = max(1, int(decoded_latents.shape[0] * 1920))
-                   target_wavs_list[i] = torch.zeros(2, frames_from_codes)
-                   continue
-           # Fallback to VAE encode from audio
-           current_wav = target_wavs_list[i].to(self.device).unsqueeze(0)
-           if self.is_silence(current_wav):
-               expected_latent_length = current_wav.shape[-1] // 1920
-               target_latent = self.silence_latent[0, :expected_latent_length, :]
-           else:
-               # Ensure input is in VAE's dtype
-               print(f"[generate_music] Encoding target audio to latents for item {i}...")
-               vae_input = current_wav.to(self.device).to(self.vae.dtype)
-               target_latent = self.vae.encode(vae_input).latent_dist.sample()
-               # Cast back to model dtype
-               target_latent = target_latent.to(self.dtype)
-           target_latent = target_latent.squeeze(0).transpose(0, 1)
-           target_latents_list.append(target_latent)
-           latent_lengths.append(target_latent.shape[0])
+
+       with self._load_model_context("vae"):
+           for i in range(batch_size):
+               code_hint = audio_code_hints[i]
+               # Prefer decoding from provided audio codes
+               if code_hint:
+                   print(f"[generate_music] Decoding audio codes for item {i}...")
+                   decoded_latents = self._decode_audio_codes_to_latents(code_hint)
+                   if decoded_latents is not None:
+                       decoded_latents = decoded_latents.squeeze(0)
+                       target_latents_list.append(decoded_latents)
+                       latent_lengths.append(decoded_latents.shape[0])
+                       # Create a silent wav matching the latent length for downstream scaling
+                       frames_from_codes = max(1, int(decoded_latents.shape[0] * 1920))
+                       target_wavs_list[i] = torch.zeros(2, frames_from_codes)
+                       continue
+               # Fallback to VAE encode from audio
+               current_wav = target_wavs_list[i].to(self.device).unsqueeze(0)
+               if self.is_silence(current_wav):
+                   expected_latent_length = current_wav.shape[-1] // 1920
+                   target_latent = self.silence_latent[0, :expected_latent_length, :]
+               else:
+                   # Ensure input is in VAE's dtype
+                   print(f"[generate_music] Encoding target audio to latents for item {i}...")
+                   vae_input = current_wav.to(self.device).to(self.vae.dtype)
+                   target_latent = self.vae.encode(vae_input).latent_dist.sample()
+                   # Cast back to model dtype
+                   target_latent = target_latent.to(self.dtype)
+               target_latent = target_latent.squeeze(0).transpose(0, 1)
+               target_latents_list.append(target_latent)
+               latent_lengths.append(target_latent.shape[0])
 
        # Pad target_wavs to consistent length for outputs
        max_target_frames = max(wav.shape[-1] for wav in target_wavs_list)
@@ -1551,7 +1713,8 @@
 
        # step 2: refer_audio timbre
        keys = batch["keys"]
-       refer_audio_acoustic_hidden_states_packed, refer_audio_order_mask = self.infer_refer_latent(batch["refer_audioss"])
+       with self._load_model_context("vae"):
+           refer_audio_acoustic_hidden_states_packed, refer_audio_order_mask = self.infer_refer_latent(batch["refer_audioss"])
        if refer_audio_acoustic_hidden_states_packed.dtype != dtype:
            refer_audio_acoustic_hidden_states_packed = refer_audio_acoustic_hidden_states_packed.to(dtype)
 
@@ -1568,22 +1731,23 @@
        text_inputs = batch["text_inputs"]
 
        print("[preprocess_batch] Inferring prompt embeddings...")
-       text_hidden_states = self.infer_text_embeddings(text_token_idss)
-       print("[preprocess_batch] Inferring lyric embeddings...")
-       lyric_hidden_states = self.infer_lyric_embeddings(lyric_token_idss)
+       with self._load_model_context("text_encoder"):
+           text_hidden_states = self.infer_text_embeddings(text_token_idss)
+           print("[preprocess_batch] Inferring lyric embeddings...")
+           lyric_hidden_states = self.infer_lyric_embeddings(lyric_token_idss)
 
-       is_covers = batch["is_covers"]
-
-       # Get precomputed hints from batch if available
-       precomputed_lm_hints_25Hz = batch.get("precomputed_lm_hints_25Hz", None)
-
-       # Get non-cover text input ids and attention masks from batch if available
-       non_cover_text_input_ids = batch.get("non_cover_text_input_ids", None)
-       non_cover_text_attention_masks = batch.get("non_cover_text_attention_masks", None)
-       non_cover_text_hidden_states = None
-       if non_cover_text_input_ids is not None:
-           print("[preprocess_batch] Inferring non-cover text embeddings...")
-           non_cover_text_hidden_states = self.infer_text_embeddings(non_cover_text_input_ids)
+           is_covers = batch["is_covers"]
+
+           # Get precomputed hints from batch if available
+           precomputed_lm_hints_25Hz = batch.get("precomputed_lm_hints_25Hz", None)
+
+           # Get non-cover text input ids and attention masks from batch if available
+           non_cover_text_input_ids = batch.get("non_cover_text_input_ids", None)
+           non_cover_text_attention_masks = batch.get("non_cover_text_attention_masks", None)
+           non_cover_text_hidden_states = None
+           if non_cover_text_input_ids is not None:
+               print("[preprocess_batch] Inferring non-cover text embeddings...")
+               non_cover_text_hidden_states = self.infer_text_embeddings(non_cover_text_input_ids)
 
        return (
            keys,
@@ -1811,7 +1975,8 @@
            "cfg_interval_end": cfg_interval_end,
        }
        print("[service_generate] Generating audio...")
-       outputs = self.model.generate_audio(**generate_kwargs)
+       with self._load_model_context("model"):
+           outputs = self.model.generate_audio(**generate_kwargs)
        return outputs
 
    def tiled_decode(self, latents, chunk_size=512, overlap=64):
@@ -1941,6 +2106,9 @@
        if progress:
            progress(0.05, desc="Preparing inputs...")
        print("[generate_music] Preparing inputs...")
+
+       # Reset offload cost
+       self.current_offload_cost = 0.0
 
        # Caption and lyrics are optional - can be empty
        # Use provided batch_size or default
@@ -2040,6 +2208,7 @@
        print("[generate_music] Model generation completed. Decoding latents...")
        pred_latents = outputs["target_latents"] # [batch, latent_length, latent_dim]
        time_costs = outputs["time_costs"]
+       time_costs["offload_time_cost"] = self.current_offload_cost
        print(f" - pred_latents: {pred_latents.shape}, dtype={pred_latents.dtype} {pred_latents.min()=}, {pred_latents.max()=}, {pred_latents.mean()=} {pred_latents.std()=}")
        print(f" - time_costs: {time_costs}")
        if progress:
@@ -2049,23 +2218,27 @@
        # Decode latents to audio
        start_time = time.time()
        with torch.no_grad():
-           # Transpose for VAE decode: [batch, latent_length, latent_dim] -> [batch, latent_dim, latent_length]
-           pred_latents_for_decode = pred_latents.transpose(1, 2)
-           # Ensure input is in VAE's dtype
-           pred_latents_for_decode = pred_latents_for_decode.to(self.vae.dtype)
-
-           if use_tiled_decode:
-               print("[generate_music] Using tiled VAE decode to reduce VRAM usage...")
-               pred_wavs = self.tiled_decode(pred_latents_for_decode) # [batch, channels, samples]
-           else:
-               pred_wavs = self.vae.decode(pred_latents_for_decode).sample
-
-           # Cast output to float32 for audio processing/saving
-           pred_wavs = pred_wavs.to(torch.float32)
+           with self._load_model_context("vae"):
+               # Transpose for VAE decode: [batch, latent_length, latent_dim] -> [batch, latent_dim, latent_length]
+               pred_latents_for_decode = pred_latents.transpose(1, 2)
+               # Ensure input is in VAE's dtype
+               pred_latents_for_decode = pred_latents_for_decode.to(self.vae.dtype)
+
+               if use_tiled_decode:
+                   print("[generate_music] Using tiled VAE decode to reduce VRAM usage...")
+                   pred_wavs = self.tiled_decode(pred_latents_for_decode) # [batch, channels, samples]
+               else:
+                   pred_wavs = self.vae.decode(pred_latents_for_decode).sample
+
+               # Cast output to float32 for audio processing/saving
+               pred_wavs = pred_wavs.to(torch.float32)
        end_time = time.time()
        time_costs["vae_decode_time_cost"] = end_time - start_time
        time_costs["total_time_cost"] = time_costs["total_time_cost"] + time_costs["vae_decode_time_cost"]
 
+       # Update offload cost one last time to include VAE offloading
+       time_costs["offload_time_cost"] = self.current_offload_cost
+
        print("[generate_music] VAE decode completed. Saving audio files...")
        if progress:
            progress(0.9, desc="Saving audio files...")
test.py CHANGED
@@ -41,7 +41,9 @@ def main():
         device=device,
         init_llm=True,
         use_flash_attention=False, # Default in UI
-        compile_model=True
+        compile_model=False,
+        offload_to_cpu=True,
+        offload_dit_to_cpu=False, # Keep DiT on GPU
     )
 
     if not enabled: