import os
import shutil
import resource  # Unix-only; used to report peak RSS after generation

import torch

from acestep.handler import AceStepHandler


def main():
    print("Initializing AceStepHandler...")
    handler = AceStepHandler()
    
    # Find checkpoints
    checkpoints = handler.get_available_checkpoints()
    if checkpoints:
        project_root = checkpoints[0]
    else:
        # Fall back to a checkpoints/ directory next to this script
        current_file = os.path.abspath(__file__)
        project_root = os.path.join(os.path.dirname(current_file), "checkpoints")
        
    print(f"Project root (checkpoints dir): {project_root}")
    
    # Find models
    models = handler.get_available_acestep_v15_models()
    if not models:
        print("No models found. Using default './acestep-v15-turbo'.")
        model_name = "./acestep-v15-turbo"
    else:
        model_name = models[0]
        print(f"Found models: {models}")
        print(f"Using model: {model_name}")
        
    # Initialize service
    device = "xpu"  # Intel GPU backend; use "cuda" or "cpu" on other hardware
    print(f"Using device: {device}")
    
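    # use_llm toggles the optional 5Hz language model pass that pre-generates
    # metadata (BPM, key, duration) and audio-code hints consumed further below.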
    use_llm = False

    status, enabled = handler.initialize_service(
        project_root=project_root,
        config_path=model_name,
        device=device,
        init_llm=use_llm,
        use_flash_attention=False, # Default in UI
        compile_model=True,
        offload_to_cpu=True,
        offload_dit_to_cpu=False, # Keep DiT on GPU
        quantization="fp8_weight_only", # Enable FP8 weight-only quantization
    )
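    # Note: compile_model=True trades longer warmup for faster inference steps,
    # while offload_to_cpu lowers peak VRAM at the cost of host<->device transfers.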
    
    if not enabled:
        print(f"Error initializing service: {status}")
        return
    
    print(status)
    print("Service initialized successfully.")
    
    # Prepare inputs
    captions = "A soft pop arrangement led by light, fingerpicked guitar sets a gentle foundation, Airy keys subtly fill the background, while delicate percussion adds warmth, The sweet female voice floats above, blending naturally with minimal harmonies in the chorus for an intimate, uplifting sound"
    
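    # Mandarin lyrics, kept in the original language to match vocal_language="zh" below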
    lyrics = """[Intro]

[Verse 1]
风吹动那年仲夏
翻开谁青涩喧哗
白枫书架
第七页码

[Verse 2]
珍藏谁的长发
星夜似手中花洒
淋湿旧忆木篱笆
木槿花下
天蓝发夹
她默认了他

[Bridge]
时光将青春的薄荷红蜡
匆匆地融化
她却沉入人海再无应答
隐没在天涯

[Chorus]
燕子在窗前飞掠
寻不到的花被季节带回
拧不干的思念如月
初恋颜色才能够描绘

木槿在窗外落雪
倾泻道别的滋味
闭上眼听见微咸的泪水
到后来才知那故梦珍贵

[Outro]"""

    seeds = "320145306, 1514681811"  # comma-separated seeds, passed through to generate_music
    
    print("Starting generation...")

    # Generate hints using 5Hz LLM
    if use_llm:
        print("Generating hints using 5Hz LLM...")
        lm_temperature = 0.6
        metadata, audio_codes, lm_status = handler.generate_with_5hz_lm(captions, lyrics, lm_temperature)
        print(f"5Hz LLM Status: {lm_status}")
        print(f"Generated Metadata: {metadata}")
        print(f"Generated Audio Codes (first 50 chars): {audio_codes[:50]}...")
    else:
        print("Skipping 5Hz LLM generation...")
        metadata = {}
        audio_codes = None
        lm_status = "Skipped"
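        # audio_codes=None means generate_music below receives no audio-code hints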

    # Use generated metadata if available, falling back to defaults
    bpm = metadata.get('bpm', 90)
    if bpm in ("N/A", ""):
        bpm = 90
    else:
        try:
            bpm = int(float(bpm))
        except (TypeError, ValueError):
            bpm = 90

    key_scale = metadata.get('keyscale', metadata.get('key_scale', "A major"))
    if key_scale == "N/A":
        key_scale = "A major"

    time_signature = metadata.get('timesignature', metadata.get('time_signature', "4"))
    if time_signature == "N/A":
        time_signature = "4"

    audio_duration = metadata.get('duration', 120)
    if audio_duration == "N/A":
        audio_duration = 120
    else:
        try:
            audio_duration = float(audio_duration)
        except (TypeError, ValueError):
            audio_duration = 120

    print(f"Using parameters: BPM={bpm}, Key={key_scale}, Time Sig={time_signature}, Duration={audio_duration}")

    # Reset peak memory stats
    if hasattr(torch, 'xpu') and torch.xpu.is_available():
        torch.xpu.reset_peak_memory_stats()
    elif torch.cuda.is_available():
        torch.cuda.reset_peak_memory_stats()
    
    # Call generate_music
    results = handler.generate_music(
        captions=captions,
        lyrics=lyrics,
        bpm=bpm,
        key_scale=key_scale,
        time_signature=time_signature,
        vocal_language="zh",
        inference_steps=8,
        guidance_scale=7.0,
        use_random_seed=False,
        seed=seeds,
        audio_duration=audio_duration,
        batch_size=1,
        task_type="text2music",
        cfg_interval_start=0.0,
        cfg_interval_end=0.95,
        audio_format="wav",
        use_tiled_decode=True,
        audio_code_string=audio_codes,
    )
    
    # Unpack the 12-tuple: two audio outputs, saved paths, info/status strings,
    # the seed(s) used, and per-track alignment score/text/plot
    (audio1, audio2, saved_files, info, status_msg, seed_val, 
     align_score1, align_text1, align_plot1, 
     align_score2, align_text2, align_plot2) = results
     
    print("\nGeneration Complete!")

    # Print peak device memory stats (XPU if present, otherwise CUDA)
    peak_vram = None
    if hasattr(torch, 'xpu') and torch.xpu.is_available():
        peak_vram = torch.xpu.max_memory_allocated() / (1024 ** 3)
    elif torch.cuda.is_available():
        peak_vram = torch.cuda.max_memory_allocated() / (1024 ** 3)
    if peak_vram is not None:
        print(f"Peak VRAM usage: {peak_vram:.2f} GB")
        
    # ru_maxrss is reported in KiB on Linux (and bytes on macOS), so this GB figure assumes Linux
    peak_ram = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / (1024 ** 2)
    print(f"Peak RAM usage: {peak_ram:.2f} GB")
    print(f"Status: {status_msg}")
    print(f"Info: {info}")
    print(f"Seeds used: {seed_val}")
    print(f"Saved files: {saved_files}")
    
    # Copy the generated audio files into the current working directory
    for f in saved_files:
        if os.path.exists(f):
            dst = os.path.basename(f)
            shutil.copy(f, dst)
            print(f"Saved output to: {os.path.abspath(dst)}")


if __name__ == "__main__":
    main()