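# Example text-to-music generation script for AceStepHandler (originally from a
# Hugging Face Space). It discovers a checkpoint directory, initializes the
# service with FP8 weight-only quantization and CPU offload (DiT kept on GPU),
# optionally generates hints with the 5Hz LLM, runs generation, reports peak
# VRAM/RAM, and copies the outputs into the current working directory.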
import os
import resource
import shutil

import torch

from acestep.handler import AceStepHandler

def main():
    print("Initializing AceStepHandler...")
    handler = AceStepHandler()

    # Find checkpoints
    checkpoints = handler.get_available_checkpoints()
    if checkpoints:
        project_root = checkpoints[0]
    else:
        # Fall back to a "checkpoints" directory next to this script
        current_file = os.path.abspath(__file__)
        project_root = os.path.join(os.path.dirname(current_file), "checkpoints")
    print(f"Project root (checkpoints dir): {project_root}")
    # Find models
    models = handler.get_available_acestep_v15_models()
    if not models:
        print("No models found. Using default 'acestep-v15-turbo'.")
        model_name = "./acestep-v15-turbo"
    else:
        model_name = models[0]
        print(f"Found models: {models}")
    print(f"Using model: {model_name}")
    # Initialize service
    device = "xpu"  # Intel XPU backend; see the portable fallback sketch below
    print(f"Using device: {device}")
    use_llm = False
    status, enabled = handler.initialize_service(
        project_root=project_root,
        config_path=model_name,
        device=device,
        init_llm=use_llm,
        use_flash_attention=False,  # Default in UI
        compile_model=True,
        offload_to_cpu=True,
        offload_dit_to_cpu=False,  # Keep DiT on GPU
        quantization="fp8_weight_only",  # Enable FP8 weight-only quantization
    )
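    # Lower-VRAM variant (a sketch using the same parameters as the call above;
    # offloading the DiT to CPU as well trades generation speed for memory):
    #
    #   status, enabled = handler.initialize_service(
    #       project_root=project_root,
    #       config_path=model_name,
    #       device=device,
    #       init_llm=use_llm,
    #       use_flash_attention=False,
    #       compile_model=False,      # skip compilation warm-up cost
    #       offload_to_cpu=True,
    #       offload_dit_to_cpu=True,  # offload the DiT too
    #       quantization="fp8_weight_only",
    #   )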
    if not enabled:
        print(f"Error initializing service: {status}")
        return
    print(status)
    print("Service initialized successfully.")
    # Prepare inputs
    captions = (
        "A soft pop arrangement led by light, fingerpicked guitar sets a gentle "
        "foundation, Airy keys subtly fill the background, while delicate "
        "percussion adds warmth, The sweet female voice floats above, blending "
        "naturally with minimal harmonies in the chorus for an intimate, "
        "uplifting sound"
    )
    # Mandarin lyrics with bracketed section tags (matches vocal_language="zh" below)
    lyrics = """[Intro]
[Verse 1]
风吹动那年仲夏
翻开谁青涩喧哗
白枫书架
第七页码
[Verse 2]
珍藏谁的长发
星夜似手中花洒
淋湿旧忆木篱笆
木槿花下
天蓝发夹
她默认了他
[Bridge]
时光将青春的薄荷红蜡
匆匆地融化
她却沉入人海再无应答
隐没在天涯
[Chorus]
燕子在窗前飞掠
寻不到的花被季节带回
拧不干的思念如月
初恋颜色才能够描绘
木槿在窗外落雪
倾泻道别的滋味
闭上眼听见微咸的泪水
到后来才知那故梦珍贵
[Outro]"""
    # Comma-separated seed string (passed through as the `seed` argument)
    seeds = "320145306, 1514681811"

    print("Starting generation...")
    # Generate hints using the 5Hz LLM
    if use_llm:
        print("Generating hints using 5Hz LLM...")
        lm_temperature = 0.6
        metadata, audio_codes, lm_status = handler.generate_with_5hz_lm(
            captions, lyrics, lm_temperature
        )
        print(f"5Hz LLM Status: {lm_status}")
        print(f"Generated Metadata: {metadata}")
        print(f"Generated Audio Codes (first 50 chars): {audio_codes[:50]}...")
    else:
        print("Skipping 5Hz LLM generation...")
        metadata = {}
        audio_codes = None
        lm_status = "Skipped"
    # Use generated metadata if available, falling back to defaults
    bpm = metadata.get("bpm", 90)
    if bpm in ("N/A", ""):
        bpm = 90
    else:
        try:
            bpm = int(float(bpm))
        except (TypeError, ValueError):
            bpm = 90

    key_scale = metadata.get("keyscale", metadata.get("key_scale", "A major"))
    if key_scale == "N/A":
        key_scale = "A major"

    time_signature = metadata.get("timesignature", metadata.get("time_signature", "4"))
    if time_signature == "N/A":
        time_signature = "4"

    audio_duration = metadata.get("duration", 120)
    if audio_duration == "N/A":
        audio_duration = 120
    else:
        try:
            audio_duration = float(audio_duration)
        except (TypeError, ValueError):
            audio_duration = 120

    print(f"Using parameters: BPM={bpm}, Key={key_scale}, Time Sig={time_signature}, Duration={audio_duration}")
    # Reset peak memory stats so the reading below covers generation only
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        torch.xpu.reset_peak_memory_stats()
    elif torch.cuda.is_available():
        torch.cuda.reset_peak_memory_stats()
    # Call generate_music
    results = handler.generate_music(
        captions=captions,
        lyrics=lyrics,
        bpm=bpm,
        key_scale=key_scale,
        time_signature=time_signature,
        vocal_language="zh",
        inference_steps=8,
        guidance_scale=7.0,
        use_random_seed=False,
        seed=seeds,
        audio_duration=audio_duration,
        batch_size=1,
        task_type="text2music",
        cfg_interval_start=0.0,
        cfg_interval_end=0.95,
        audio_format="wav",
        use_tiled_decode=True,
        audio_code_string=audio_codes,
    )
    # Unpack results
    (audio1, audio2, saved_files, info, status_msg, seed_val,
     align_score1, align_text1, align_plot1,
     align_score2, align_text2, align_plot2) = results
| print("\nGeneration Complete!") | |
| # Print memory stats | |
| if hasattr(torch, 'xpu') and torch.xpu.is_available(): | |
| peak_vram = torch.xpu.max_memory_allocated() / (1024 ** 3) | |
| print(f"Peak VRAM usage: {peak_vram:.2f} GB") | |
| elif torch.cuda.is_available(): | |
| peak_vram = torch.cuda.max_memory_allocated() / (1024 ** 3) | |
| print(f"Peak VRAM usage: {peak_vram:.2f} GB") | |
| peak_ram = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / (1024 ** 2) | |
| print(f"Peak RAM usage: {peak_ram:.2f} GB") | |
| print(f"Status: {status_msg}") | |
| print(f"Info: {info}") | |
| print(f"Seeds used: {seed_val}") | |
| print(f"Saved files: {saved_files}") | |
    # Copy output files into the current working directory
    for f in saved_files:
        if os.path.exists(f):
            dst = os.path.basename(f)
            shutil.copy(f, dst)
            print(f"Saved output to: {os.path.abspath(dst)}")

if __name__ == "__main__":
    main()
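
# To inspect a generated file programmatically (a sketch; assumes torchaudio is
# installed, and would run inside main() where saved_files is in scope):
#
#   import torchaudio
#   waveform, sample_rate = torchaudio.load(saved_files[0])
#   print(waveform.shape, sample_rate)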