Spaces:
Running
on
A100
Running
on
A100
fix bugs for save
Browse files
acestep/audio_utils.py
CHANGED
|
@@ -89,15 +89,12 @@ class AudioSaver:
|
|
| 89 |
try:
|
| 90 |
if format == "mp3":
|
| 91 |
# MP3 uses ffmpeg backend
|
| 92 |
-
from torchaudio.io import CodecConfig
|
| 93 |
-
config = CodecConfig(bit_rate=192000, compression_level=1)
|
| 94 |
torchaudio.save(
|
| 95 |
str(output_path),
|
| 96 |
audio_tensor,
|
| 97 |
sample_rate,
|
| 98 |
channels_first=True,
|
| 99 |
backend='ffmpeg',
|
| 100 |
-
compression=config,
|
| 101 |
)
|
| 102 |
elif format in ["flac", "wav"]:
|
| 103 |
# FLAC and WAV use soundfile backend (fastest)
|
|
@@ -106,7 +103,7 @@ class AudioSaver:
|
|
| 106 |
audio_tensor,
|
| 107 |
sample_rate,
|
| 108 |
channels_first=True,
|
| 109 |
-
backend='
|
| 110 |
)
|
| 111 |
else:
|
| 112 |
# Other formats use default backend
|
|
|
|
| 89 |
try:
|
| 90 |
if format == "mp3":
|
| 91 |
# MP3 uses ffmpeg backend
|
|
|
|
|
|
|
| 92 |
torchaudio.save(
|
| 93 |
str(output_path),
|
| 94 |
audio_tensor,
|
| 95 |
sample_rate,
|
| 96 |
channels_first=True,
|
| 97 |
backend='ffmpeg',
|
|
|
|
| 98 |
)
|
| 99 |
elif format in ["flac", "wav"]:
|
| 100 |
# FLAC and WAV use soundfile backend (fastest)
|
|
|
|
| 103 |
audio_tensor,
|
| 104 |
sample_rate,
|
| 105 |
channels_first=True,
|
| 106 |
+
backend='soundfile',
|
| 107 |
)
|
| 108 |
else:
|
| 109 |
# Other formats use default backend
|
acestep/gradio_ui/events/__init__.py
CHANGED
|
@@ -254,48 +254,84 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
|
|
| 254 |
]
|
| 255 |
)
|
| 256 |
|
| 257 |
-
# Save buttons for
|
| 258 |
-
|
| 259 |
-
|
| 260 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 261 |
inputs=[
|
| 262 |
results_section[f"generated_audio_{btn_idx}"],
|
| 263 |
-
|
| 264 |
-
generation_section["captions"],
|
| 265 |
-
generation_section["lyrics"],
|
| 266 |
-
generation_section["vocal_language"],
|
| 267 |
-
generation_section["bpm"],
|
| 268 |
-
generation_section["key_scale"],
|
| 269 |
-
generation_section["time_signature"],
|
| 270 |
-
generation_section["audio_duration"],
|
| 271 |
-
generation_section["batch_size_input"],
|
| 272 |
-
generation_section["inference_steps"],
|
| 273 |
-
generation_section["guidance_scale"],
|
| 274 |
-
generation_section["seed"],
|
| 275 |
-
generation_section["random_seed_checkbox"],
|
| 276 |
-
generation_section["use_adg"],
|
| 277 |
-
generation_section["cfg_interval_start"],
|
| 278 |
-
generation_section["cfg_interval_end"],
|
| 279 |
-
generation_section["audio_format"],
|
| 280 |
-
generation_section["lm_temperature"],
|
| 281 |
-
generation_section["lm_cfg_scale"],
|
| 282 |
-
generation_section["lm_top_k"],
|
| 283 |
-
generation_section["lm_top_p"],
|
| 284 |
-
generation_section["lm_negative_prompt"],
|
| 285 |
-
generation_section["use_cot_caption"],
|
| 286 |
-
generation_section["use_cot_language"],
|
| 287 |
-
generation_section["audio_cover_strength"],
|
| 288 |
-
generation_section["think_checkbox"],
|
| 289 |
-
generation_section["text2music_audio_code_string"],
|
| 290 |
-
generation_section["repainting_start"],
|
| 291 |
-
generation_section["repainting_end"],
|
| 292 |
-
generation_section["track_name"],
|
| 293 |
-
generation_section["complete_track_classes"],
|
| 294 |
-
results_section["lm_metadata_state"],
|
| 295 |
],
|
| 296 |
-
|
| 297 |
-
|
| 298 |
-
|
| 299 |
# ========== Send to SRC Handlers ==========
|
| 300 |
for btn_idx in range(1, 9):
|
| 301 |
results_section[f"send_to_src_btn_{btn_idx}"].click(
|
|
|
|
| 254 |
]
|
| 255 |
)
|
| 256 |
|
| 257 |
+
# Save buttons for all 8 audio outputs
|
| 258 |
+
download_existing_js = """(current_audio, batch_files) => {
|
| 259 |
+
// Debug: print what the input actually is
|
| 260 |
+
console.log("👉 [Debug] Current Audio Input:", current_audio);
|
| 261 |
+
|
| 262 |
+
// 1. Safety check
|
| 263 |
+
if (!current_audio) {
|
| 264 |
+
console.warn("⚠️ No audio selected or audio is empty.");
|
| 265 |
+
return;
|
| 266 |
+
}
|
| 267 |
+
if (!batch_files || !Array.isArray(batch_files)) {
|
| 268 |
+
console.warn("⚠️ Batch file list is empty/not ready.");
|
| 269 |
+
return;
|
| 270 |
+
}
|
| 271 |
+
|
| 272 |
+
// 2. Smartly extract path string
|
| 273 |
+
let pathString = "";
|
| 274 |
+
|
| 275 |
+
if (typeof current_audio === "string") {
|
| 276 |
+
// Case A: direct path string received
|
| 277 |
+
pathString = current_audio;
|
| 278 |
+
} else if (typeof current_audio === "object") {
|
| 279 |
+
// Case B: an object is received, try common properties
|
| 280 |
+
// Gradio file objects usually have path, url, or name
|
| 281 |
+
pathString = current_audio.path || current_audio.name || current_audio.url || "";
|
| 282 |
+
}
|
| 283 |
+
|
| 284 |
+
if (!pathString) {
|
| 285 |
+
console.error("❌ Error: Could not extract a valid path string from input.", current_audio);
|
| 286 |
+
return;
|
| 287 |
+
}
|
| 288 |
+
|
| 289 |
+
// 3. Extract Key (UUID)
|
| 290 |
+
// Path could be /tmp/.../uuid.mp3 or url like /file=.../uuid.mp3
|
| 291 |
+
let filename = pathString.split(/[\\\\/]/).pop(); // get the filename
|
| 292 |
+
let key = filename.split('.')[0]; // get UUID without extension
|
| 293 |
+
|
| 294 |
+
console.log(`🔑 Key extracted: ${key}`);
|
| 295 |
+
|
| 296 |
+
// 4. Find matching file(s) in the list
|
| 297 |
+
let targets = batch_files.filter(f => {
|
| 298 |
+
// Also extract names from batch_files objects
|
| 299 |
+
// f usually contains name (backend path) and orig_name (download name)
|
| 300 |
+
const fPath = f.name || f.path || "";
|
| 301 |
+
return fPath.includes(key);
|
| 302 |
+
});
|
| 303 |
+
|
| 304 |
+
if (targets.length === 0) {
|
| 305 |
+
console.warn("❌ No matching files found in batch list for key:", key);
|
| 306 |
+
alert("Batch list does not contain this file yet. Please wait for generation to finish.");
|
| 307 |
+
return;
|
| 308 |
+
}
|
| 309 |
+
|
| 310 |
+
// 5. Trigger download(s)
|
| 311 |
+
console.log(`🎯 Found ${targets.length} files to download.`);
|
| 312 |
+
targets.forEach((f, index) => {
|
| 313 |
+
setTimeout(() => {
|
| 314 |
+
const a = document.createElement('a');
|
| 315 |
+
// Prefer url (frontend-accessible link), otherwise try data
|
| 316 |
+
a.href = f.url || f.data;
|
| 317 |
+
a.download = f.orig_name || "download";
|
| 318 |
+
a.style.display = 'none';
|
| 319 |
+
document.body.appendChild(a);
|
| 320 |
+
a.click();
|
| 321 |
+
document.body.removeChild(a);
|
| 322 |
+
}, index * 1000); // 300ms interval to avoid browser blocking
|
| 323 |
+
});
|
| 324 |
+
}
|
| 325 |
+
"""
|
| 326 |
+
for btn_idx in range(1, 9):
|
| 327 |
+
results_section[f"save_btn_{btn_idx}"].click(
|
| 328 |
+
fn=None,
|
| 329 |
inputs=[
|
| 330 |
results_section[f"generated_audio_{btn_idx}"],
|
| 331 |
+
results_section["generated_audio_batch"],
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 332 |
],
|
| 333 |
+
js=download_existing_js # Run the above JS
|
| 334 |
+
)
|
|
|
|
| 335 |
# ========== Send to SRC Handlers ==========
|
| 336 |
for btn_idx in range(1, 9):
|
| 337 |
results_section[f"send_to_src_btn_{btn_idx}"].click(
|
acestep/gradio_ui/events/results_handlers.py
CHANGED
|
@@ -180,99 +180,6 @@ def update_navigation_buttons(current_batch, total_batches):
|
|
| 180 |
can_go_next = current_batch < total_batches - 1
|
| 181 |
return can_go_previous, can_go_next
|
| 182 |
|
| 183 |
-
|
| 184 |
-
def save_audio_and_metadata(
|
| 185 |
-
audio_path, task_type, captions, lyrics, vocal_language, bpm, key_scale, time_signature, audio_duration,
|
| 186 |
-
batch_size_input, inference_steps, guidance_scale, seed, random_seed_checkbox,
|
| 187 |
-
use_adg, cfg_interval_start, cfg_interval_end, audio_format,
|
| 188 |
-
lm_temperature, lm_cfg_scale, lm_top_k, lm_top_p, lm_negative_prompt,
|
| 189 |
-
use_cot_caption, use_cot_language, audio_cover_strength,
|
| 190 |
-
think_checkbox, text2music_audio_code_string, repainting_start, repainting_end,
|
| 191 |
-
track_name, complete_track_classes, lm_metadata
|
| 192 |
-
):
|
| 193 |
-
"""Save audio file and its metadata as a zip package"""
|
| 194 |
-
if audio_path is None:
|
| 195 |
-
gr.Warning(t("messages.no_audio_to_save"))
|
| 196 |
-
return None
|
| 197 |
-
|
| 198 |
-
try:
|
| 199 |
-
# Create metadata dictionary
|
| 200 |
-
metadata = {
|
| 201 |
-
"saved_at": datetime.datetime.now().isoformat(),
|
| 202 |
-
"task_type": task_type,
|
| 203 |
-
"caption": captions or "",
|
| 204 |
-
"lyrics": lyrics or "",
|
| 205 |
-
"vocal_language": vocal_language,
|
| 206 |
-
"bpm": bpm if bpm is not None else None,
|
| 207 |
-
"keyscale": key_scale or "",
|
| 208 |
-
"timesignature": time_signature or "",
|
| 209 |
-
"duration": audio_duration if audio_duration is not None else -1,
|
| 210 |
-
"batch_size": batch_size_input,
|
| 211 |
-
"inference_steps": inference_steps,
|
| 212 |
-
"guidance_scale": guidance_scale,
|
| 213 |
-
"seed": seed,
|
| 214 |
-
"random_seed": False, # Disable random seed for reproducibility
|
| 215 |
-
"use_adg": use_adg,
|
| 216 |
-
"cfg_interval_start": cfg_interval_start,
|
| 217 |
-
"cfg_interval_end": cfg_interval_end,
|
| 218 |
-
"audio_format": audio_format,
|
| 219 |
-
"lm_temperature": lm_temperature,
|
| 220 |
-
"lm_cfg_scale": lm_cfg_scale,
|
| 221 |
-
"lm_top_k": lm_top_k,
|
| 222 |
-
"lm_top_p": lm_top_p,
|
| 223 |
-
"lm_negative_prompt": lm_negative_prompt,
|
| 224 |
-
"use_cot_caption": use_cot_caption,
|
| 225 |
-
"use_cot_language": use_cot_language,
|
| 226 |
-
"audio_cover_strength": audio_cover_strength,
|
| 227 |
-
"think": think_checkbox,
|
| 228 |
-
"audio_codes": text2music_audio_code_string or "",
|
| 229 |
-
"repainting_start": repainting_start,
|
| 230 |
-
"repainting_end": repainting_end,
|
| 231 |
-
"track_name": track_name,
|
| 232 |
-
"complete_track_classes": complete_track_classes or [],
|
| 233 |
-
}
|
| 234 |
-
|
| 235 |
-
# Add LM-generated metadata if available
|
| 236 |
-
if lm_metadata:
|
| 237 |
-
metadata["lm_generated_metadata"] = lm_metadata
|
| 238 |
-
|
| 239 |
-
# Generate timestamp and base name
|
| 240 |
-
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
|
| 241 |
-
|
| 242 |
-
# Extract audio filename extension
|
| 243 |
-
audio_ext = os.path.splitext(audio_path)[1]
|
| 244 |
-
|
| 245 |
-
# Create temporary directory for packaging
|
| 246 |
-
temp_dir = tempfile.mkdtemp()
|
| 247 |
-
|
| 248 |
-
# Save JSON metadata
|
| 249 |
-
json_path = os.path.join(temp_dir, f"metadata_{timestamp}.json")
|
| 250 |
-
with open(json_path, 'w', encoding='utf-8') as f:
|
| 251 |
-
json.dump(metadata, f, indent=2, ensure_ascii=False)
|
| 252 |
-
|
| 253 |
-
# Copy audio file
|
| 254 |
-
audio_copy_path = os.path.join(temp_dir, f"audio_{timestamp}{audio_ext}")
|
| 255 |
-
shutil.copy2(audio_path, audio_copy_path)
|
| 256 |
-
|
| 257 |
-
# Create zip file
|
| 258 |
-
zip_path = os.path.join(tempfile.gettempdir(), f"music_package_{timestamp}.zip")
|
| 259 |
-
with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
|
| 260 |
-
zipf.write(audio_copy_path, os.path.basename(audio_copy_path))
|
| 261 |
-
zipf.write(json_path, os.path.basename(json_path))
|
| 262 |
-
|
| 263 |
-
# Clean up temp directory
|
| 264 |
-
shutil.rmtree(temp_dir)
|
| 265 |
-
|
| 266 |
-
gr.Info(t("messages.save_success", filename=os.path.basename(zip_path)))
|
| 267 |
-
return zip_path
|
| 268 |
-
|
| 269 |
-
except Exception as e:
|
| 270 |
-
gr.Warning(t("messages.save_failed", error=str(e)))
|
| 271 |
-
import traceback
|
| 272 |
-
traceback.print_exc()
|
| 273 |
-
return None
|
| 274 |
-
|
| 275 |
-
|
| 276 |
def send_audio_to_src_with_metadata(audio_file, lm_metadata):
|
| 277 |
"""Send generated audio file to src_audio input and populate metadata fields
|
| 278 |
|
|
@@ -455,16 +362,17 @@ def generate_with_progress(
|
|
| 455 |
align_plot_2 = None
|
| 456 |
updated_audio_codes = text2music_audio_code_string if not think_checkbox else ""
|
| 457 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 458 |
if not result.success:
|
| 459 |
-
|
| 460 |
-
generation_info = _build_generation_info(
|
| 461 |
-
lm_metadata=lm_generated_metadata,
|
| 462 |
-
time_costs=time_costs,
|
| 463 |
-
seed_value=seed_value_for_ui,
|
| 464 |
-
inference_steps=inference_steps,
|
| 465 |
-
num_audios=0,
|
| 466 |
-
)
|
| 467 |
-
yield (None,) * 8 + (None, generation_info, result.status_message) + (gr.skip(),) * 25
|
| 468 |
return
|
| 469 |
|
| 470 |
audios = result.audios
|
|
@@ -480,8 +388,11 @@ def generate_with_progress(
|
|
| 480 |
json_path = os.path.join(temp_dir, f"{key}.json")
|
| 481 |
audio_path = os.path.join(temp_dir, f"{key}.{audio_format}")
|
| 482 |
save_audio(audio_data=audio_tensor, output_path=audio_path, sample_rate=sample_rate, format=audio_format, channels_first=True)
|
|
|
|
|
|
|
| 483 |
audio_outputs[i] = audio_path
|
| 484 |
all_audio_paths.append(audio_path)
|
|
|
|
| 485 |
|
| 486 |
code_str = audio_params.get("audio_codes", "")
|
| 487 |
final_codes_list[i] = code_str
|
|
|
|
| 180 |
can_go_next = current_batch < total_batches - 1
|
| 181 |
return can_go_previous, can_go_next
|
| 182 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 183 |
def send_audio_to_src_with_metadata(audio_file, lm_metadata):
|
| 184 |
"""Send generated audio file to src_audio input and populate metadata fields
|
| 185 |
|
|
|
|
| 362 |
align_plot_2 = None
|
| 363 |
updated_audio_codes = text2music_audio_code_string if not think_checkbox else ""
|
| 364 |
|
| 365 |
+
# Build initial generation_info (will be updated with post-processing times at the end)
|
| 366 |
+
generation_info = _build_generation_info(
|
| 367 |
+
lm_metadata=lm_generated_metadata,
|
| 368 |
+
time_costs=time_costs,
|
| 369 |
+
seed_value=seed_value_for_ui,
|
| 370 |
+
inference_steps=inference_steps,
|
| 371 |
+
num_audios=len(result.audios) if result.success else 0,
|
| 372 |
+
)
|
| 373 |
+
|
| 374 |
if not result.success:
|
| 375 |
+
yield (None,) * 8 + (None, generation_info, result.status_message) + (gr.skip(),) * 26
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 376 |
return
|
| 377 |
|
| 378 |
audios = result.audios
|
|
|
|
| 388 |
json_path = os.path.join(temp_dir, f"{key}.json")
|
| 389 |
audio_path = os.path.join(temp_dir, f"{key}.{audio_format}")
|
| 390 |
save_audio(audio_data=audio_tensor, output_path=audio_path, sample_rate=sample_rate, format=audio_format, channels_first=True)
|
| 391 |
+
with open(json_path, 'w', encoding='utf-8') as f:
|
| 392 |
+
json.dump(audio_params, f, indent=2, ensure_ascii=False)
|
| 393 |
audio_outputs[i] = audio_path
|
| 394 |
all_audio_paths.append(audio_path)
|
| 395 |
+
all_audio_paths.append(json_path)
|
| 396 |
|
| 397 |
code_str = audio_params.get("audio_codes", "")
|
| 398 |
final_codes_list[i] = code_str
|