import os
import sys


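# Make the in-repo tokenizer packages importable when running from the project root.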
sys.path.append('./codeclm/tokenizer')
sys.path.append('./codeclm/tokenizer/Flow1dVAE')
sys.path.append('.')

import torch

import numpy as np
from omegaconf import OmegaConf

from codeclm.models import builders
from codeclm.models import CodecLM

from separator import Separator
from generate import check_language_by_text


class LeVoInference(torch.nn.Module):
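    """End-to-end LeVo song generation: lyrics plus an optional audio or genre prompt in, waveform out."""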
    def __init__(self, ckpt_path):
        super().__init__()

        torch.backends.cudnn.enabled = False
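        # Custom resolvers used when interpolating the checkpoint's config.yaml.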
        OmegaConf.register_new_resolver("eval", lambda x: eval(x))
        OmegaConf.register_new_resolver("concat", lambda *x: [xxx for xx in x for xxx in xx])
        OmegaConf.register_new_resolver("get_fname", lambda: 'default')
        OmegaConf.register_new_resolver("load_yaml", lambda x: list(OmegaConf.load(x)))

        cfg_path = os.path.join(ckpt_path, 'config.yaml')
        pt_path = os.path.join(ckpt_path, 'model.pt')

        self.cfg = OmegaConf.load(cfg_path)
        self.cfg.mode = 'inference'
        self.max_duration = self.cfg.max_dur

        # Define model or load pretrained model
        audiolm = builders.get_lm_model(self.cfg, version='v1.5')
        checkpoint = torch.load(pt_path, map_location='cpu')
        audiolm_state_dict = {k.replace('audiolm.', ''): v for k, v in checkpoint.items() if k.startswith('audiolm')}
        audiolm.load_state_dict(audiolm_state_dict, strict=False)
        audiolm = audiolm.eval()
        audiolm = audiolm.cuda().to(torch.float16)

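        # Tokenizer that encodes the audio prompt into discrete codes for the LM.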
        audio_tokenizer = builders.get_audio_tokenizer_model(self.cfg.audio_tokenizer_checkpoint, self.cfg)
        audio_tokenizer = audio_tokenizer.eval()

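        # Second tokenizer for the separated vocal/BGM stems ("seperate" spelling matches the library API).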
        seperate_tokenizer = builders.get_audio_tokenizer_model(self.cfg.audio_tokenizer_checkpoint_sep, self.cfg)
        seperate_tokenizer = seperate_tokenizer.eval()

        self.model = CodecLM(
            name="tmp",
            lm=audiolm,
            audiotokenizer=audio_tokenizer,
            max_duration=self.max_duration,
            seperate_tokenizer=seperate_tokenizer,
        )
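        # Source separator used to split a reference track into vocal and BGM stems.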
        self.separator = Separator()


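        # Default sampling parameters; forward() accepts per-call overrides.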
        self.default_params = dict(
            cfg_coef=1.5,
            temperature=1.0,
            top_k=50,
            top_p=0.0,
            record_tokens=True,
            record_window=50,
            extend_stride=5,
            duration=self.max_duration,
        )

        self.model.set_generation_params(**self.default_params)

    def forward(self, lyric: str, description: str = None, prompt_audio_path: os.PathLike = None, genre: str = None, auto_prompt_path: os.PathLike = None, gen_type: str = "mixed", params: dict = None):
        # Avoid a shared mutable default; merge per-call overrides into the defaults.
        params = {**self.default_params, **(params or {})}
        self.model.set_generation_params(**params)

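        # Three prompt modes: separate a reference track into stems, sample a
        # stored genre prompt of pre-computed tokens, or run unprompted.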
        if prompt_audio_path is not None and os.path.exists(prompt_audio_path):
            pmt_wav, vocal_wav, bgm_wav = self.separator.run(prompt_audio_path)
            melody_is_wav = True
        elif genre is not None and auto_prompt_path is not None:
            auto_prompt = torch.load(auto_prompt_path)
            if genre == 'Auto':
                lang = check_language_by_text(lyric)
                prompt_token = auto_prompt['Auto'][lang][np.random.randint(0, len(auto_prompt['Auto'][lang]))]
            else:
                prompt_token = auto_prompt[genre][np.random.randint(0, len(auto_prompt[genre]))]
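            # A stored prompt token tensor carries three stem channels: [mix, vocal, bgm].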
            pmt_wav = prompt_token[:, [0], :]
            vocal_wav = prompt_token[:, [1], :]
            bgm_wav = prompt_token[:, [2], :]
            melody_is_wav = False
        else:
            pmt_wav = None
            vocal_wav = None
            bgm_wav = None
            melody_is_wav = True

        description = description if description else '.'
        description = '[Musicality-very-high], ' + description
        generate_inp = {
            'lyrics': [lyric.replace("  ", " ")],
            'descriptions': [description],
            'melody_wavs': pmt_wav,
            'vocal_wavs': vocal_wav,
            'bgm_wavs': bgm_wav,
            'melody_is_wav': melody_is_wav,
        }

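        # Sample tokens under fp16 autocast; decoding back to audio needs no gradients.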
        with torch.autocast(device_type="cuda", dtype=torch.float16):
            tokens = self.model.generate(**generate_inp, return_tokens=True)
            
        with torch.no_grad():
            if melody_is_wav:
                wav_seperate = self.model.generate_audio(tokens, pmt_wav, vocal_wav, bgm_wav, gen_type=gen_type)
            else:
                wav_seperate = self.model.generate_audio(tokens, gen_type=gen_type)

        return wav_seperate[0]
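

# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal example of driving LeVoInference end to end. It assumes a checkpoint
# directory containing config.yaml and model.pt, and that the loaded config
# defines `sample_rate`; the path, lyric, and description below are placeholder
# assumptions, not values from this repository.
if __name__ == "__main__":
    import torchaudio

    model = LeVoInference("ckpt/songgeneration_base")  # hypothetical checkpoint dir
    wav = model(
        lyric="[verse] Walking down an empty street tonight",
        description="pop, female vocal, piano, 90 bpm",
        gen_type="mixed",
    )
    # wav is expected to be a [channels, samples] tensor on the GPU.
    torchaudio.save("output.flac", wav.cpu().float(), model.cfg.sample_rate)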