ChuxiJ committed on
Commit
388b5af
·
1 Parent(s): 03f73c6

fix bugs for save

Browse files
acestep/audio_utils.py CHANGED
@@ -89,15 +89,12 @@ class AudioSaver:
89
  try:
90
  if format == "mp3":
91
  # MP3 uses ffmpeg backend
92
- from torchaudio.io import CodecConfig
93
- config = CodecConfig(bit_rate=192000, compression_level=1)
94
  torchaudio.save(
95
  str(output_path),
96
  audio_tensor,
97
  sample_rate,
98
  channels_first=True,
99
  backend='ffmpeg',
100
- compression=config,
101
  )
102
  elif format in ["flac", "wav"]:
103
  # FLAC and WAV use soundfile backend (fastest)
@@ -106,7 +103,7 @@ class AudioSaver:
106
  audio_tensor,
107
  sample_rate,
108
  channels_first=True,
109
- backend='ffmpeg',
110
  )
111
  else:
112
  # Other formats use default backend
 
89
  try:
90
  if format == "mp3":
91
  # MP3 uses ffmpeg backend
 
 
92
  torchaudio.save(
93
  str(output_path),
94
  audio_tensor,
95
  sample_rate,
96
  channels_first=True,
97
  backend='ffmpeg',
 
98
  )
99
  elif format in ["flac", "wav"]:
100
  # FLAC and WAV use soundfile backend (fastest)
 
103
  audio_tensor,
104
  sample_rate,
105
  channels_first=True,
106
+ backend='soundfile',
107
  )
108
  else:
109
  # Other formats use default backend
acestep/gradio_ui/events/__init__.py CHANGED
@@ -254,48 +254,84 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
254
  ]
255
  )
256
 
257
- # Save buttons for audio 1 and 2
258
- for btn_idx, btn_key in [(1, "save_btn_1"), (2, "save_btn_2")]:
259
- results_section[btn_key].click(
260
- fn=res_h.save_audio_and_metadata,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
261
  inputs=[
262
  results_section[f"generated_audio_{btn_idx}"],
263
- generation_section["task_type"],
264
- generation_section["captions"],
265
- generation_section["lyrics"],
266
- generation_section["vocal_language"],
267
- generation_section["bpm"],
268
- generation_section["key_scale"],
269
- generation_section["time_signature"],
270
- generation_section["audio_duration"],
271
- generation_section["batch_size_input"],
272
- generation_section["inference_steps"],
273
- generation_section["guidance_scale"],
274
- generation_section["seed"],
275
- generation_section["random_seed_checkbox"],
276
- generation_section["use_adg"],
277
- generation_section["cfg_interval_start"],
278
- generation_section["cfg_interval_end"],
279
- generation_section["audio_format"],
280
- generation_section["lm_temperature"],
281
- generation_section["lm_cfg_scale"],
282
- generation_section["lm_top_k"],
283
- generation_section["lm_top_p"],
284
- generation_section["lm_negative_prompt"],
285
- generation_section["use_cot_caption"],
286
- generation_section["use_cot_language"],
287
- generation_section["audio_cover_strength"],
288
- generation_section["think_checkbox"],
289
- generation_section["text2music_audio_code_string"],
290
- generation_section["repainting_start"],
291
- generation_section["repainting_end"],
292
- generation_section["track_name"],
293
- generation_section["complete_track_classes"],
294
- results_section["lm_metadata_state"],
295
  ],
296
- outputs=[gr.File(label="Download Package", visible=False)]
297
- )
298
-
299
  # ========== Send to SRC Handlers ==========
300
  for btn_idx in range(1, 9):
301
  results_section[f"send_to_src_btn_{btn_idx}"].click(
 
254
  ]
255
  )
256
 
257
+ # Save buttons for all 8 audio outputs
258
+ download_existing_js = """(current_audio, batch_files) => {
259
+ // Debug: print what the input actually is
260
+ console.log("👉 [Debug] Current Audio Input:", current_audio);
261
+
262
+ // 1. Safety check
263
+ if (!current_audio) {
264
+ console.warn("⚠️ No audio selected or audio is empty.");
265
+ return;
266
+ }
267
+ if (!batch_files || !Array.isArray(batch_files)) {
268
+ console.warn("⚠️ Batch file list is empty/not ready.");
269
+ return;
270
+ }
271
+
272
+ // 2. Smartly extract path string
273
+ let pathString = "";
274
+
275
+ if (typeof current_audio === "string") {
276
+ // Case A: direct path string received
277
+ pathString = current_audio;
278
+ } else if (typeof current_audio === "object") {
279
+ // Case B: an object is received, try common properties
280
+ // Gradio file objects usually have path, url, or name
281
+ pathString = current_audio.path || current_audio.name || current_audio.url || "";
282
+ }
283
+
284
+ if (!pathString) {
285
+ console.error("❌ Error: Could not extract a valid path string from input.", current_audio);
286
+ return;
287
+ }
288
+
289
+ // 3. Extract Key (UUID)
290
+ // Path could be /tmp/.../uuid.mp3 or url like /file=.../uuid.mp3
291
+ let filename = pathString.split(/[\\\\/]/).pop(); // get the filename
292
+ let key = filename.split('.')[0]; // get UUID without extension
293
+
294
+ console.log(`🔑 Key extracted: ${key}`);
295
+
296
+ // 4. Find matching file(s) in the list
297
+ let targets = batch_files.filter(f => {
298
+ // Also extract names from batch_files objects
299
+ // f usually contains name (backend path) and orig_name (download name)
300
+ const fPath = f.name || f.path || "";
301
+ return fPath.includes(key);
302
+ });
303
+
304
+ if (targets.length === 0) {
305
+ console.warn("❌ No matching files found in batch list for key:", key);
306
+ alert("Batch list does not contain this file yet. Please wait for generation to finish.");
307
+ return;
308
+ }
309
+
310
+ // 5. Trigger download(s)
311
+ console.log(`🎯 Found ${targets.length} files to download.`);
312
+ targets.forEach((f, index) => {
313
+ setTimeout(() => {
314
+ const a = document.createElement('a');
315
+ // Prefer url (frontend-accessible link), otherwise try data
316
+ a.href = f.url || f.data;
317
+ a.download = f.orig_name || "download";
318
+ a.style.display = 'none';
319
+ document.body.appendChild(a);
320
+ a.click();
321
+ document.body.removeChild(a);
322
+ }, index * 1000); // 1000ms (1s) interval between downloads to avoid browser blocking
323
+ });
324
+ }
325
+ """
326
+ for btn_idx in range(1, 9):
327
+ results_section[f"save_btn_{btn_idx}"].click(
328
+ fn=None,
329
  inputs=[
330
  results_section[f"generated_audio_{btn_idx}"],
331
+ results_section["generated_audio_batch"],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
332
  ],
333
+ js=download_existing_js # Run the above JS
334
+ )
 
335
  # ========== Send to SRC Handlers ==========
336
  for btn_idx in range(1, 9):
337
  results_section[f"send_to_src_btn_{btn_idx}"].click(
acestep/gradio_ui/events/results_handlers.py CHANGED
@@ -180,99 +180,6 @@ def update_navigation_buttons(current_batch, total_batches):
180
  can_go_next = current_batch < total_batches - 1
181
  return can_go_previous, can_go_next
182
 
183
-
184
- def save_audio_and_metadata(
185
- audio_path, task_type, captions, lyrics, vocal_language, bpm, key_scale, time_signature, audio_duration,
186
- batch_size_input, inference_steps, guidance_scale, seed, random_seed_checkbox,
187
- use_adg, cfg_interval_start, cfg_interval_end, audio_format,
188
- lm_temperature, lm_cfg_scale, lm_top_k, lm_top_p, lm_negative_prompt,
189
- use_cot_caption, use_cot_language, audio_cover_strength,
190
- think_checkbox, text2music_audio_code_string, repainting_start, repainting_end,
191
- track_name, complete_track_classes, lm_metadata
192
- ):
193
- """Save audio file and its metadata as a zip package"""
194
- if audio_path is None:
195
- gr.Warning(t("messages.no_audio_to_save"))
196
- return None
197
-
198
- try:
199
- # Create metadata dictionary
200
- metadata = {
201
- "saved_at": datetime.datetime.now().isoformat(),
202
- "task_type": task_type,
203
- "caption": captions or "",
204
- "lyrics": lyrics or "",
205
- "vocal_language": vocal_language,
206
- "bpm": bpm if bpm is not None else None,
207
- "keyscale": key_scale or "",
208
- "timesignature": time_signature or "",
209
- "duration": audio_duration if audio_duration is not None else -1,
210
- "batch_size": batch_size_input,
211
- "inference_steps": inference_steps,
212
- "guidance_scale": guidance_scale,
213
- "seed": seed,
214
- "random_seed": False, # Disable random seed for reproducibility
215
- "use_adg": use_adg,
216
- "cfg_interval_start": cfg_interval_start,
217
- "cfg_interval_end": cfg_interval_end,
218
- "audio_format": audio_format,
219
- "lm_temperature": lm_temperature,
220
- "lm_cfg_scale": lm_cfg_scale,
221
- "lm_top_k": lm_top_k,
222
- "lm_top_p": lm_top_p,
223
- "lm_negative_prompt": lm_negative_prompt,
224
- "use_cot_caption": use_cot_caption,
225
- "use_cot_language": use_cot_language,
226
- "audio_cover_strength": audio_cover_strength,
227
- "think": think_checkbox,
228
- "audio_codes": text2music_audio_code_string or "",
229
- "repainting_start": repainting_start,
230
- "repainting_end": repainting_end,
231
- "track_name": track_name,
232
- "complete_track_classes": complete_track_classes or [],
233
- }
234
-
235
- # Add LM-generated metadata if available
236
- if lm_metadata:
237
- metadata["lm_generated_metadata"] = lm_metadata
238
-
239
- # Generate timestamp and base name
240
- timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
241
-
242
- # Extract audio filename extension
243
- audio_ext = os.path.splitext(audio_path)[1]
244
-
245
- # Create temporary directory for packaging
246
- temp_dir = tempfile.mkdtemp()
247
-
248
- # Save JSON metadata
249
- json_path = os.path.join(temp_dir, f"metadata_{timestamp}.json")
250
- with open(json_path, 'w', encoding='utf-8') as f:
251
- json.dump(metadata, f, indent=2, ensure_ascii=False)
252
-
253
- # Copy audio file
254
- audio_copy_path = os.path.join(temp_dir, f"audio_{timestamp}{audio_ext}")
255
- shutil.copy2(audio_path, audio_copy_path)
256
-
257
- # Create zip file
258
- zip_path = os.path.join(tempfile.gettempdir(), f"music_package_{timestamp}.zip")
259
- with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
260
- zipf.write(audio_copy_path, os.path.basename(audio_copy_path))
261
- zipf.write(json_path, os.path.basename(json_path))
262
-
263
- # Clean up temp directory
264
- shutil.rmtree(temp_dir)
265
-
266
- gr.Info(t("messages.save_success", filename=os.path.basename(zip_path)))
267
- return zip_path
268
-
269
- except Exception as e:
270
- gr.Warning(t("messages.save_failed", error=str(e)))
271
- import traceback
272
- traceback.print_exc()
273
- return None
274
-
275
-
276
  def send_audio_to_src_with_metadata(audio_file, lm_metadata):
277
  """Send generated audio file to src_audio input and populate metadata fields
278
 
@@ -455,16 +362,17 @@ def generate_with_progress(
455
  align_plot_2 = None
456
  updated_audio_codes = text2music_audio_code_string if not think_checkbox else ""
457
 
 
 
 
 
 
 
 
 
 
458
  if not result.success:
459
- # Build generation_info string for error case
460
- generation_info = _build_generation_info(
461
- lm_metadata=lm_generated_metadata,
462
- time_costs=time_costs,
463
- seed_value=seed_value_for_ui,
464
- inference_steps=inference_steps,
465
- num_audios=0,
466
- )
467
- yield (None,) * 8 + (None, generation_info, result.status_message) + (gr.skip(),) * 25
468
  return
469
 
470
  audios = result.audios
@@ -480,8 +388,11 @@ def generate_with_progress(
480
  json_path = os.path.join(temp_dir, f"{key}.json")
481
  audio_path = os.path.join(temp_dir, f"{key}.{audio_format}")
482
  save_audio(audio_data=audio_tensor, output_path=audio_path, sample_rate=sample_rate, format=audio_format, channels_first=True)
 
 
483
  audio_outputs[i] = audio_path
484
  all_audio_paths.append(audio_path)
 
485
 
486
  code_str = audio_params.get("audio_codes", "")
487
  final_codes_list[i] = code_str
 
180
  can_go_next = current_batch < total_batches - 1
181
  return can_go_previous, can_go_next
182
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
183
  def send_audio_to_src_with_metadata(audio_file, lm_metadata):
184
  """Send generated audio file to src_audio input and populate metadata fields
185
 
 
362
  align_plot_2 = None
363
  updated_audio_codes = text2music_audio_code_string if not think_checkbox else ""
364
 
365
+ # Build initial generation_info (will be updated with post-processing times at the end)
366
+ generation_info = _build_generation_info(
367
+ lm_metadata=lm_generated_metadata,
368
+ time_costs=time_costs,
369
+ seed_value=seed_value_for_ui,
370
+ inference_steps=inference_steps,
371
+ num_audios=len(result.audios) if result.success else 0,
372
+ )
373
+
374
  if not result.success:
375
+ yield (None,) * 8 + (None, generation_info, result.status_message) + (gr.skip(),) * 26
 
 
 
 
 
 
 
 
376
  return
377
 
378
  audios = result.audios
 
388
  json_path = os.path.join(temp_dir, f"{key}.json")
389
  audio_path = os.path.join(temp_dir, f"{key}.{audio_format}")
390
  save_audio(audio_data=audio_tensor, output_path=audio_path, sample_rate=sample_rate, format=audio_format, channels_first=True)
391
+ with open(json_path, 'w', encoding='utf-8') as f:
392
+ json.dump(audio_params, f, indent=2, ensure_ascii=False)
393
  audio_outputs[i] = audio_path
394
  all_audio_paths.append(audio_path)
395
+ all_audio_paths.append(json_path)
396
 
397
  code_str = audio_params.get("audio_codes", "")
398
  final_codes_list[i] = code_str