import logging
import math
import os.path
import re
from typing import List

import librosa
import numpy as np
import torch

from contants import config
from gpt_sovits.AR.models.t2s_lightning_module import Text2SemanticLightningModule
from gpt_sovits.module.mel_processing import spectrogram_torch
from gpt_sovits.module.models import SynthesizerTrn
from gpt_sovits.utils import DictToAttrRecursive
from gpt_sovits.text import cleaned_text_to_sequence
from gpt_sovits.text.cleaner import clean_text
from utils.classify_language import classify_language
from utils.data_utils import check_is_none
from utils.sentence import split_languages, sentence_split

splits = {",", "。", "?", "!", ",", ".", "?", "!", "~", ":", ":", "—", "…"}

class GPT_SoVITS:
    def __init__(self, sovits_path, gpt_path, device, **kwargs):
        self.sovits_path = sovits_path
        self.gpt_path = gpt_path
        self.hz = config.gpt_sovits_config.hz
        self.sampling_rate = None
        self.device = device
        self.model_handler = None
        self.is_half = config.gpt_sovits_config.is_half
        self.np_dtype = np.float16 if self.is_half else np.float32
        self.torch_dtype = torch.float16 if self.is_half else torch.float32
        self.speakers = None
        self.lang = ["zh", "ja", "en"]
        self.flash_attn_enabled = True
        self.prompt_cache: dict = {
            "ref_audio_path": None,
            "prompt_semantic": None,
            "refer_spepc": None,
            "prompt_text": None,
            "prompt_lang": None,
            "phones": None,
            "bert_features": None,
            "norm_text": None,
        }

    def load_model(self, model_handler):
        self.model_handler = model_handler
        self.load_sovits(self.sovits_path)
        self.load_gpt(self.gpt_path)
        self.tokenizer, self.bert_model = self.model_handler.get_bert_model("CHINESE_ROBERTA_WWM_EXT_LARGE")
        self.ssl_model = self.model_handler.get_ssl_model()

    def load_weight(self, saved_state_dict, model):
        if hasattr(model, 'module'):
            state_dict = model.module.state_dict()
        else:
            state_dict = model.state_dict()
        new_state_dict = {}
        for k, v in state_dict.items():
            try:
                new_state_dict[k] = saved_state_dict[k]
            except KeyError:
                # logging.info(f"{k} is not in the checkpoint")
                new_state_dict[k] = v  # keep the freshly initialized weight for keys missing from the checkpoint
        if hasattr(model, 'module'):
            model.module.load_state_dict(new_state_dict)
        else:
            model.load_state_dict(new_state_dict)

    def load_sovits(self, sovits_path):
        # self.n_semantic = 1024
        logging.info(f"Loaded checkpoint '{sovits_path}'")
        dict_s2 = torch.load(sovits_path, map_location=self.device)
        self.hps = dict_s2["config"]
        self.hps = DictToAttrRecursive(self.hps)
        self.hps.model.semantic_frame_rate = "25hz"
        # self.speakers = [self.hps.get("name")]  # take the name from the model config
        self.speakers = [os.path.basename(os.path.dirname(self.sovits_path))]  # use the model folder name as the speaker name
        self.vq_model = SynthesizerTrn(
            self.hps.data.filter_length // 2 + 1,
            self.hps.train.segment_size // self.hps.data.hop_length,
            n_speakers=self.hps.data.n_speakers,
            **self.hps.model).to(self.device)
        if config.gpt_sovits_config.is_half:
            self.vq_model = self.vq_model.half()
        self.vq_model.eval()
        self.sampling_rate = self.hps.data.sampling_rate
        self.load_weight(dict_s2['weight'], self.vq_model)

    def load_gpt(self, gpt_path):
        logging.info(f"Loaded checkpoint '{gpt_path}'")
        dict_s1 = torch.load(gpt_path, map_location=self.device)
        self.gpt_config = dict_s1["config"]
        self.max_sec = self.gpt_config.get("data").get("max_sec")
        self.t2s_model = Text2SemanticLightningModule(self.gpt_config, "****", is_train=False,
                                                      flash_attn_enabled=self.flash_attn_enabled).to(self.device)
        self.load_weight(dict_s1['weight'], self.t2s_model)
        if config.gpt_sovits_config.is_half:
            self.t2s_model = self.t2s_model.half()
        self.t2s_model.eval()
        total = sum(param.nelement() for param in self.t2s_model.parameters())
        logging.info(f"Number of parameters: {total / 1e6:.2f}M")

    def get_speakers(self):
        return self.speakers

    def get_cleaned_text(self, text, language):
        phones, word2ph, norm_text = clean_text(text, language.replace("all_", ""))
        phones = cleaned_text_to_sequence(phones)
        return phones, word2ph, norm_text

    def get_cleaned_text_multilang(self, text):
        sentences = split_languages(text, expand_abbreviations=True, expand_hyphens=True)
        phones, word2ph, norm_text = [], [], []
        for sentence, lang in sentences:
            lang = classify_language(sentence)  # re-classify each sentence; the tag from split_languages is discarded
            _phones, _word2ph, _norm_text = self.get_cleaned_text(sentence, lang)
            phones.extend(_phones)
            word2ph.extend(_word2ph)
            norm_text.extend(_norm_text)
        return phones, word2ph, norm_text

    def get_bert_feature(self, text, phones, word2ph, language):
        if language == "zh":
            with torch.no_grad():
                inputs = self.tokenizer(text, return_tensors="pt")
                for i in inputs:
                    inputs[i] = inputs[i].to(self.device)  # inputs are long tensors, so precision follows bert_model
                res = self.bert_model(**inputs, output_hidden_states=True)
                res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1]
            assert len(word2ph) == len(text)
            phone_level_feature = []
            for i in range(len(word2ph)):
                repeat_feature = res[i].repeat(word2ph[i], 1)
                phone_level_feature.append(repeat_feature)
            phone_level_feature = torch.cat(phone_level_feature, dim=0)
            # if config.gpt_sovits_config.is_half: phone_level_feature = phone_level_feature.half()
            bert = phone_level_feature.T
            torch.cuda.empty_cache()
        else:
            bert = torch.zeros((1024, len(phones)), dtype=self.torch_dtype)
        return bert
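    # Shape note (added for clarity): the returned tensor is (1024, n_phones),
    # where 1024 is the hidden size of chinese-roberta-wwm-ext-large; for
    # non-Chinese text an all-zero tensor of the same shape is returned.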

    def get_bert_and_cleaned_text_multilang(self, text: str):
        sentences = split_languages(text, expand_abbreviations=True, expand_hyphens=True)
        phones, word2ph, norm_text, bert = [], [], [], []
        for sentence, lang in sentences:
            _phones, _word2ph, _norm_text = self.get_cleaned_text(sentence, lang)
            # pass the normalized text and the language, matching preprocess_prompt's
            # call; the original passed _norm_text in the language slot
            _bert = self.get_bert_feature(_norm_text, _phones, _word2ph, lang)
            phones.extend(_phones)
            if _word2ph is not None:
                word2ph.extend(_word2ph)
            norm_text.extend(_norm_text)
            bert.append(_bert)
        bert = torch.cat(bert, dim=1).to(self.device, dtype=self.torch_dtype)
        return phones, word2ph, norm_text, bert

    def get_spepc(self, audio, orig_sr):
        """Resample audio to the model's sampling_rate and compute its linear spectrogram."""
        audio = librosa.resample(audio, orig_sr=orig_sr, target_sr=int(self.hps.data.sampling_rate))
        audio = torch.FloatTensor(audio)
        audio_norm = audio.unsqueeze(0)
        spec = spectrogram_torch(
            audio_norm,
            self.hps.data.filter_length,
            self.hps.data.sampling_rate,
            self.hps.data.hop_length,
            self.hps.data.win_length,
            center=False,
        )
        return spec
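    # Shape note (added for clarity): spectrogram_torch returns a tensor of
    # shape (1, filter_length // 2 + 1, n_frames); it is cached in
    # prompt_cache["refer_spepc"] and later consumed by vq_model.decode.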

    def _set_prompt_semantic(self, reference_audio, reference_audio_sr):
        zero_wav = np.zeros(
            int(self.sampling_rate * 0.3),
            dtype=np.float16 if self.is_half else np.float32,
        )
        wav16k = librosa.resample(reference_audio, orig_sr=reference_audio_sr, target_sr=16000)
        with torch.no_grad():
            if wav16k.shape[0] > 160000 or wav16k.shape[0] < 48000:
                raise OSError("Reference audio must be between 3 and 10 seconds long; please use a different clip!")
            wav16k = torch.from_numpy(wav16k)
            zero_wav_torch = torch.from_numpy(zero_wav)
            if self.is_half:
                wav16k = wav16k.half()
                zero_wav_torch = zero_wav_torch.half()
            wav16k = wav16k.to(self.device)
            zero_wav_torch = zero_wav_torch.to(self.device)
            wav16k = torch.cat([wav16k, zero_wav_torch]).unsqueeze(0)
            ssl_content = self.ssl_model.model(wav16k)["last_hidden_state"].transpose(1, 2)  # .float()
            codes = self.vq_model.extract_latent(ssl_content)
            prompt_semantic = codes[0, 0].to(self.device)
            # prompt_semantic = prompt_semantic.unsqueeze(0).to(self.device)
            self.prompt_cache["prompt_semantic"] = prompt_semantic
        torch.cuda.empty_cache()

    def get_first(self, text):
        pattern = "[" + "".join(re.escape(sep) for sep in splits) + "]"
        text = re.split(pattern, text)[0].strip()
        return text
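    # Minimal illustration (comments only, to keep this module import-safe):
    # with "," and "。" in splits, get_first("你好,世界。") returns "你好",
    # i.e. everything before the first sentence separator.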

    def preprocess_text(self, text: str, lang: str, segment_size: int):
        texts = sentence_split(text, segment_size)
        result = []
        for text in texts:
            phones, word2ph, norm_text, bert_features = self.get_bert_and_cleaned_text_multilang(text)
            res = {
                "phones": phones,
                "bert_features": bert_features,
                "norm_text": norm_text,
            }
            result.append(res)
        return result

    def preprocess_prompt(self, reference_audio, reference_audio_sr, prompt_text: str, prompt_lang: str):
        if self.prompt_cache.get("prompt_text") != prompt_text:
            if prompt_lang.lower() == "auto":
                prompt_lang = classify_language(prompt_text)
            if prompt_text[-1] not in splits:
                prompt_text += "。" if prompt_lang != "en" else "."
            phones, word2ph, norm_text = self.get_cleaned_text(prompt_text, prompt_lang)
            bert_features = self.get_bert_feature(norm_text, phones, word2ph, prompt_lang).to(self.device,
                                                                                              dtype=self.torch_dtype)
            self.prompt_cache["prompt_text"] = prompt_text
            self.prompt_cache["prompt_lang"] = prompt_lang
            self.prompt_cache["phones"] = phones
            self.prompt_cache["bert_features"] = bert_features
            self.prompt_cache["norm_text"] = norm_text
            self.prompt_cache["refer_spepc"] = self.get_spepc(reference_audio, orig_sr=reference_audio_sr)
            self._set_prompt_semantic(reference_audio, reference_audio_sr)

    def batch_sequences(self, sequences: List[torch.Tensor], axis: int = 0, pad_value: int = 0,
                        max_length: int = None):
        seq = sequences[0]
        ndim = seq.dim()
        if axis < 0:
            axis += ndim
        dtype: torch.dtype = seq.dtype
        pad_value = torch.tensor(pad_value, dtype=dtype)
        seq_lengths = [seq.shape[axis] for seq in sequences]
        if max_length is None:
            max_length = max(seq_lengths)
        else:
            max_length = max(max_length, max(seq_lengths))
        padded_sequences = []
        for seq, length in zip(sequences, seq_lengths):
            padding = [0] * axis + [0, max_length - length] + [0] * (ndim - axis - 1)
            padded_seq = torch.nn.functional.pad(seq, padding, value=pad_value)
            padded_sequences.append(padded_seq)
        batch = torch.stack(padded_sequences)
        return batch
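    # Hedged usage sketch (comments only): padding three 1-D LongTensors of
    # lengths 2, 3 and 5 along axis 0 with pad_value=0 yields a (3, 5) batch,
    # with trailing zeros appended to the two shorter rows:
    #   batch_sequences([torch.ones(2).long(), torch.ones(3).long(), torch.ones(5).long()])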

    def to_batch(self, data: list, prompt_data: dict = None, batch_size: int = 5, threshold: float = 0.75,
                 split_bucket: bool = True):
        _data: list = []
        index_and_len_list = []
        for idx, item in enumerate(data):
            norm_text_len = len(item["norm_text"])
            index_and_len_list.append([idx, norm_text_len])

        batch_index_list = []
        if split_bucket:
            index_and_len_list.sort(key=lambda x: x[1])
            index_and_len_list = np.array(index_and_len_list, dtype=np.int64)
            batch_index_list_len = 0
            pos = 0
            while pos < index_and_len_list.shape[0]:
                # batch_index_list.append(index_and_len_list[pos:min(pos+batch_size,len(index_and_len_list))])
                pos_end = min(pos + batch_size, index_and_len_list.shape[0])
                while pos < pos_end:
                    batch = index_and_len_list[pos:pos_end, 1].astype(np.float32)
                    score = batch[(pos_end - pos) // 2] / batch.mean()
                    if (score >= threshold) or (pos_end - pos == 1):
                        batch_index = index_and_len_list[pos:pos_end, 0].tolist()
                        batch_index_list_len += len(batch_index)
                        batch_index_list.append(batch_index)
                        pos = pos_end
                        break
                    pos_end = pos_end - 1
            assert batch_index_list_len == len(data)
        else:
            for i in range(len(data)):
                if i % batch_size == 0:
                    batch_index_list.append([])
                batch_index_list[-1].append(i)

        for batch_idx, index_list in enumerate(batch_index_list):
            item_list = [data[idx] for idx in index_list]
            phones_list = []
            phones_len_list = []
            # bert_features_list = []
            all_phones_list = []
            all_phones_len_list = []
            all_bert_features_list = []
            norm_text_batch = []
            bert_max_len = 0
            phones_max_len = 0
            for item in item_list:
                if prompt_data is not None:
                    all_bert_features = torch.cat([prompt_data["bert_features"], item["bert_features"]], 1)
                    all_phones = torch.LongTensor(prompt_data["phones"] + item["phones"])
                    phones = torch.LongTensor(item["phones"])
                    # norm_text = prompt_data["norm_text"] + item["norm_text"]
                else:
                    all_bert_features = item["bert_features"]
                    phones = torch.LongTensor(item["phones"])
                    all_phones = phones
                    # norm_text = item["norm_text"]
                bert_max_len = max(bert_max_len, all_bert_features.shape[-1])
                phones_max_len = max(phones_max_len, phones.shape[-1])
                phones_list.append(phones)
                phones_len_list.append(phones.shape[-1])
                all_phones_list.append(all_phones)
                all_phones_len_list.append(all_phones.shape[-1])
                all_bert_features_list.append(all_bert_features)
                norm_text_batch.append(item["norm_text"])
            phones_batch = phones_list
            max_len = max(bert_max_len, phones_max_len)
            # phones_batch = self.batch_sequences(phones_list, axis=0, pad_value=0, max_length=max_len)
            all_phones_batch = self.batch_sequences(all_phones_list, axis=0, pad_value=0, max_length=max_len)
            all_bert_features_batch = torch.FloatTensor(len(item_list), 1024, max_len)
            all_bert_features_batch.zero_()
            for idx, item in enumerate(all_bert_features_list):
                if item is not None:
                    all_bert_features_batch[idx, :, : item.shape[-1]] = item
            batch = {
                "phones": phones_batch,
                "phones_len": torch.LongTensor(phones_len_list),
                "all_phones": all_phones_batch,
                "all_phones_len": torch.LongTensor(all_phones_len_list),
                "all_bert_features": all_bert_features_batch,
                "norm_text": norm_text_batch,
            }
            _data.append(batch)
        return _data, batch_index_list
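    # Bucketing note (added for clarity): with split_bucket=True the inner loop
    # shrinks a candidate batch until its median length is at least `threshold`
    # times its mean length, so each bucket holds similarly sized texts and
    # padding waste stays bounded. E.g. sorted lengths [3, 4, 5, 40] with
    # batch_size=4 fail the check (5 / 13 < 0.75), so the bucket shrinks to
    # [3, 4, 5] (4 / 4 >= 0.75) and [40] becomes its own single-item batch.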

    def recovery_order(self, data: list, batch_index_list: list) -> list:
        '''
        Recover the original order of the audio fragments according to batch_index_list.

        Args:
            data (List[list[np.ndarray]]): the out-of-order audio fragments.
            batch_index_list (List[list[int]]): the batch index list.

        Returns:
            list (List[np.ndarray]): the data in the original order.
        '''
        length = len(sum(batch_index_list, []))
        _data = [None] * length
        for i, index_list in enumerate(batch_index_list):
            for j, index in enumerate(index_list):
                _data[index] = data[i][j]
        return _data
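    # Example (added for clarity): batch_index_list [[2, 0], [1]] with
    # data [[a, c], [b]] restores the original order [c, b, a]:
    # data[0][0] goes to slot 2, data[0][1] to slot 0, data[1][0] to slot 1.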

    def audio_postprocess(self, audio: List[torch.Tensor], sr: int, batch_index_list: list = None,
                          speed_factor: float = 1.0, split_bucket: bool = True) -> np.ndarray:
        zero_wav = torch.zeros(
            int(self.sampling_rate * 0.3),
            dtype=torch.float16 if self.is_half else torch.float32,
            device=self.device
        )
        for i, batch in enumerate(audio):
            for j, audio_fragment in enumerate(batch):
                max_audio = torch.abs(audio_fragment).max()  # simple guard against 16-bit clipping
                if max_audio > 1:
                    audio_fragment /= max_audio
                audio_fragment: torch.Tensor = torch.cat([audio_fragment, zero_wav], dim=0)
                audio[i][j] = audio_fragment.cpu().numpy()

        if split_bucket:
            audio = self.recovery_order(audio, batch_index_list)
        else:
            # audio = [item for batch in audio for item in batch]
            audio = sum(audio, [])

        audio = np.concatenate(audio, 0)

        try:
            if speed_factor != 1.0:
                audio = self.speed_change(audio, speed_factor=speed_factor, sr=int(sr))
        except Exception as e:
            logging.error(f"Failed to change the speed of the audio: \n{e}")

        return audio

    def speed_change(self, input_audio: np.ndarray, speed_factor: float, sr: int):
        # time-stretch without changing pitch
        processed_audio = librosa.effects.time_stretch(input_audio, rate=speed_factor)
        return processed_audio
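    # Note (added for clarity): librosa.effects.time_stretch changes duration
    # without shifting pitch; rate > 1.0 shortens the audio, rate < 1.0
    # lengthens it, so speed_factor=2.0 roughly halves the output length.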

    def infer(self, text, lang, reference_audio, reference_audio_sr, prompt_text, prompt_lang, top_k, top_p,
              temperature, batch_size: int = 5, batch_threshold: float = 0.75, split_bucket: bool = True,
              return_fragment: bool = False, speed_factor: float = 1.0,
              segment_size: int = config.gpt_sovits_config.segment_size, **kwargs):
        if return_fragment:
            split_bucket = False

        data = self.preprocess_text(text, lang, segment_size)

        no_prompt_text = False
        if check_is_none(prompt_text):
            no_prompt_text = True
        else:
            self.preprocess_prompt(reference_audio, reference_audio_sr, prompt_text, prompt_lang)

        data, batch_index_list = self.to_batch(data,
                                               prompt_data=self.prompt_cache if not no_prompt_text else None,
                                               batch_size=batch_size,
                                               threshold=batch_threshold,
                                               split_bucket=split_bucket
                                               )
        audio = []
        for item in data:
            batch_phones = item["phones"]
            batch_phones_len = item["phones_len"]
            all_phoneme_ids = item["all_phones"]
            all_phoneme_lens = item["all_phones_len"]
            all_bert_features = item["all_bert_features"]
            norm_text = item["norm_text"]

            # batch_phones = batch_phones.to(self.device)
            batch_phones_len = batch_phones_len.to(self.device)
            all_phoneme_ids = all_phoneme_ids.to(self.device)
            all_phoneme_lens = all_phoneme_lens.to(self.device)
            all_bert_features = all_bert_features.to(self.device)
            if self.is_half:
                all_bert_features = all_bert_features.half()

            logging.debug(f"Infer text:{[''.join(text) for text in norm_text]}")

            if no_prompt_text:
                prompt = None
            else:
                prompt = self.prompt_cache["prompt_semantic"].expand(all_phoneme_ids.shape[0], -1).to(self.device)

            with torch.no_grad():
                pred_semantic_list, idx_list = self.t2s_model.model.infer_panel(
                    all_phoneme_ids,
                    all_phoneme_lens,
                    prompt,
                    all_bert_features,
                    # prompt_phone_len=ph_offset,
                    top_k=top_k,
                    top_p=top_p,
                    temperature=temperature,
                    early_stop_num=self.hz * self.max_sec,
                )

                refer_audio_spepc: torch.Tensor = self.prompt_cache["refer_spepc"].to(self.device)
                if self.is_half:
                    refer_audio_spepc = refer_audio_spepc.half()

                # keep only the newly predicted tokens of each item, dropping the prompt
                pred_semantic_list = [item[-idx:] for item, idx in zip(pred_semantic_list, idx_list)]
                upsample_rate = math.prod(self.vq_model.upsample_rates)
                audio_frag_idx = [pred_semantic_list[i].shape[0] * 2 * upsample_rate
                                  for i in range(0, len(pred_semantic_list))]
                audio_frag_end_idx = [sum(audio_frag_idx[:i + 1]) for i in range(0, len(audio_frag_idx))]
                all_pred_semantic = torch.cat(pred_semantic_list).unsqueeze(0).unsqueeze(0).to(self.device)
                _batch_phones = torch.cat(batch_phones).unsqueeze(0).to(self.device)
                _batch_audio_fragment = (self.vq_model.decode(
                    all_pred_semantic, _batch_phones, refer_audio_spepc
                ).detach()[0, 0, :])
                # split the decoded waveform back into per-text fragments
                audio_frag_end_idx.insert(0, 0)
                batch_audio_fragment = [_batch_audio_fragment[audio_frag_end_idx[i - 1]:audio_frag_end_idx[i]]
                                        for i in range(1, len(audio_frag_end_idx))]

            torch.cuda.empty_cache()

            if return_fragment:
                yield self.audio_postprocess([batch_audio_fragment],
                                             reference_audio_sr,
                                             batch_index_list,
                                             speed_factor,
                                             split_bucket)
            else:
                audio.append(batch_audio_fragment)

        if not return_fragment:
            yield self.audio_postprocess(audio,
                                         reference_audio_sr,
                                         batch_index_list,
                                         speed_factor,
                                         split_bucket)
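

if __name__ == "__main__":
    # Hedged usage sketch: every path, text string and the model_handler below
    # are placeholders, not part of this module's API. load_model() needs the
    # project's model-handler object (which supplies the BERT and SSL models),
    # so fill in the placeholder before running.
    import soundfile as sf  # assumed to be available in the environment

    device = "cuda" if torch.cuda.is_available() else "cpu"
    tts = GPT_SoVITS(
        sovits_path="models/my_speaker/my_speaker.pth",  # hypothetical path
        gpt_path="models/my_speaker/my_speaker.ckpt",  # hypothetical path
        device=device,
    )
    model_handler = ...  # placeholder: supplied by the project's model-loading layer
    tts.load_model(model_handler)

    ref_audio, ref_sr = librosa.load("reference.wav", sr=None)  # a 3-10 s clip
    # infer() is a generator: it yields one full waveform, or per-segment
    # fragments when return_fragment=True.
    for wav in tts.infer(
            text="你好,世界。",  # placeholder text to synthesize
            lang="auto",
            reference_audio=ref_audio,
            reference_audio_sr=ref_sr,
            prompt_text="这是参考音频对应的文本。",  # transcript of reference.wav
            prompt_lang="auto",
            top_k=5,
            top_p=1.0,
            temperature=1.0,
    ):
        sf.write("output.wav", wav, tts.sampling_rate)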