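"""Standalone inference test for the ACE-Step v1.5 handler.

Initializes the service on an Intel XPU device with FP8 weight-only
quantization, generates a single text2music sample from a fixed caption,
Chinese lyrics, and fixed seeds, then reports peak VRAM/RAM usage and
copies the rendered audio into the current working directory.
"""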
import os
import resource
import shutil

import torch

from acestep.handler import AceStepHandler
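# NOTE: `resource` is POSIX-only; on Windows the peak-RSS report at the end
# of main() would need a substitute such as psutil.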


def main():
    print("Initializing AceStepHandler...")
    handler = AceStepHandler()

    # Find checkpoints
    checkpoints = handler.get_available_checkpoints()
    if checkpoints:
        project_root = checkpoints[0]
    else:
        # Fallback: a checkpoints/ directory next to this script
        current_file = os.path.abspath(__file__)
        project_root = os.path.join(os.path.dirname(current_file), "checkpoints")
    print(f"Project root (checkpoints dir): {project_root}")

    # Find models
    models = handler.get_available_acestep_v15_models()
    if not models:
        print("No models found. Using default 'acestep-v15-turbo'.")
        model_name = "./acestep-v15-turbo"
    else:
        model_name = models[0]
        print(f"Found models: {models}")
    print(f"Using model: {model_name}")

    # Initialize service
    device = "xpu"
    print(f"Using device: {device}")
    use_llm = False
    status, enabled = handler.initialize_service(
        project_root=project_root,
        config_path=model_name,
        device=device,
        init_llm=use_llm,
        use_flash_attention=False,  # Default in UI
        compile_model=True,
        offload_to_cpu=True,
        offload_dit_to_cpu=False,  # Keep DiT on GPU
        quantization="fp8_weight_only",  # Enable FP8 weight-only quantization
    )
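    # (Assumed flag semantics, inferred from the names: compile_model trades
    # first-run warmup time for faster steady-state inference, offload_to_cpu
    # parks idle submodules in system RAM, and fp8_weight_only quantization
    # lowers VRAM usage at a small quality cost.)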
    if not enabled:
        print(f"Error initializing service: {status}")
        return
    print(status)
    print("Service initialized successfully.")

    # Prepare inputs
    captions = (
        "A soft pop arrangement led by light, fingerpicked guitar sets a gentle "
        "foundation, Airy keys subtly fill the background, while delicate "
        "percussion adds warmth, The sweet female voice floats above, blending "
        "naturally with minimal harmonies in the chorus for an intimate, "
        "uplifting sound"
    )
    lyrics = """[Intro]
[Verse 1]
风吹动那年仲夏
翻开谁青涩喧哗
白枫书架
第七页码
[Verse 2]
珍藏谁的长发
星夜似手中花洒
淋湿旧忆木篱笆
木槿花下
天蓝发夹
她默认了他
[Bridge]
时光将青春的薄荷红蜡
匆匆地融化
她却沉入人海再无应答
隐没在天涯
[Chorus]
燕子在窗前飞掠
寻不到的花被季节带回
拧不干的思念如月
初恋颜色才能够描绘
木槿在窗外落雪
倾泻道别的滋味
闭上眼听见微咸的泪水
到后来才知那故梦珍贵
[Outro]"""
    seeds = "320145306, 1514681811"
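    # `seeds` is a comma-separated string; presumably the handler parses one
    # seed per generated sample when use_random_seed=False.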
print("Starting generation...")
# Generate hints using 5Hz LLM
if use_llm:
print("Generating hints using 5Hz LLM...")
lm_temperature = 0.6
metadata, audio_codes, lm_status = handler.generate_with_5hz_lm(captions, lyrics, lm_temperature)
print(f"5Hz LLM Status: {lm_status}")
print(f"Generated Metadata: {metadata}")
print(f"Generated Audio Codes (first 50 chars): {audio_codes[:50]}...")
else:
print("Skipping 5Hz LLM generation...")
metadata = {}
audio_codes = None
lm_status = "Skipped"
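
    # The 5Hz LLM can return "N/A" or empty strings for fields it cannot
    # infer, so each metadata value below is validated and given a default.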
    # Use generated metadata if available
    bpm = metadata.get('bpm', 90)
    if bpm in ("N/A", ""):
        bpm = 90
    else:
        try:
            bpm = int(float(bpm))
        except (TypeError, ValueError):
            bpm = 90
    key_scale = metadata.get('keyscale', metadata.get('key_scale', "A major"))
    if key_scale == "N/A":
        key_scale = "A major"
    time_signature = metadata.get('timesignature', metadata.get('time_signature', "4"))
    if time_signature == "N/A":
        time_signature = "4"
    audio_duration = metadata.get('duration', 120)
    if audio_duration == "N/A":
        audio_duration = 120
    else:
        try:
            audio_duration = float(audio_duration)
        except (TypeError, ValueError):
            audio_duration = 120
    print(f"Using parameters: BPM={bpm}, Key={key_scale}, Time Sig={time_signature}, Duration={audio_duration}")

    # Reset peak memory stats
    if hasattr(torch, 'xpu') and torch.xpu.is_available():
        torch.xpu.reset_peak_memory_stats()
    elif torch.cuda.is_available():
        torch.cuda.reset_peak_memory_stats()
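    # Resetting here scopes max_memory_allocated() below to the generation
    # call itself, excluding model loading and warmup.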

    # Call generate_music
    results = handler.generate_music(
        captions=captions,
        lyrics=lyrics,
        bpm=bpm,
        key_scale=key_scale,
        time_signature=time_signature,
        vocal_language="zh",
        inference_steps=8,
        guidance_scale=7.0,
        use_random_seed=False,
        seed=seeds,
        audio_duration=audio_duration,
        batch_size=1,
        task_type="text2music",
        cfg_interval_start=0.0,
        cfg_interval_end=0.95,
        audio_format="wav",
        use_tiled_decode=True,
        audio_code_string=audio_codes,
    )
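    # generate_music returns a flat 12-tuple: two audio outputs (the UI
    # presumably exposes two result slots regardless of batch_size), saved
    # file paths, info/status strings, the seeds actually used, and
    # per-sample lyric-alignment scores, texts, and plots.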
    # Unpack results
    (audio1, audio2, saved_files, info, status_msg, seed_val,
     align_score1, align_text1, align_plot1,
     align_score2, align_text2, align_plot2) = results

    print("\nGeneration Complete!")

    # Print memory stats (ru_maxrss is reported in KiB on Linux)
    if hasattr(torch, 'xpu') and torch.xpu.is_available():
        peak_vram = torch.xpu.max_memory_allocated() / (1024 ** 3)
        print(f"Peak VRAM usage: {peak_vram:.2f} GB")
    elif torch.cuda.is_available():
        peak_vram = torch.cuda.max_memory_allocated() / (1024 ** 3)
        print(f"Peak VRAM usage: {peak_vram:.2f} GB")
    peak_ram = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / (1024 ** 2)
    print(f"Peak RAM usage: {peak_ram:.2f} GB")

    print(f"Status: {status_msg}")
    print(f"Info: {info}")
    print(f"Seeds used: {seed_val}")
    print(f"Saved files: {saved_files}")

    # Copy output files into the current working directory
    for f in saved_files:
        if os.path.exists(f):
            dst = os.path.basename(f)
            shutil.copy(f, dst)
            print(f"Saved output to: {os.path.abspath(dst)}")


if __name__ == "__main__":
    main()