| | from ..models.model_manager import ModelManager |
| | import torch |
| |
|
| |
|
| |
|
def tokenize_long_prompt(tokenizer, prompt, max_length=None):
    """Tokenize ``prompt`` and fold the ids into rows of a fixed length.

    The prompt is tokenized once without a length limit to count its tokens,
    then tokenized again, padded/truncated to the next multiple of the chunk
    length, and reshaped to ``(num_chunks, chunk_length)``.

    Args:
        tokenizer: Callable tokenizer exposing a writable ``model_max_length``
            attribute and returning an object with ``input_ids`` when called
            (presumably a Hugging Face tokenizer -- confirm against callers).
        prompt: Text passed straight through to ``tokenizer``.
        max_length: Chunk length. Defaults to ``tokenizer.model_max_length``.

    Returns:
        The ``input_ids`` of the second tokenization, reshaped to
        ``(num_chunks, chunk_length)``.
    """
    # Remember the tokenizer's own limit so it can be put back afterwards,
    # even when the caller overrides the chunk length or tokenization fails.
    original_limit = tokenizer.model_max_length
    length = original_limit if max_length is None else max_length

    try:
        # Lift the limit for the counting pass (presumably to keep the
        # tokenizer from complaining about over-long input -- confirm).
        tokenizer.model_max_length = 99999999
        token_count = tokenizer(prompt, return_tensors="pt").input_ids.shape[1]

        # Round the token count up to the next multiple of the chunk length.
        padded_length = (token_count + length - 1) // length * length

        # Second pass: pad/truncate to exactly ``padded_length`` tokens.
        tokenizer.model_max_length = length
        input_ids = tokenizer(
            prompt,
            return_tensors="pt",
            padding="max_length",
            max_length=padded_length,
            truncation=True,
        ).input_ids
    finally:
        tokenizer.model_max_length = original_limit

    # Split the single padded row into ``num_sentence`` rows of ``length``.
    num_sentence = input_ids.shape[1] // length
    return input_ids.reshape((num_sentence, length))
| |
|
| |
|
| |
|
class BasePrompter:
    """Holds prompt refiners and extenders and applies them to prompts."""

    def __init__(self):
        # Callables applied in order by ``process_prompt``.
        self.refiners = []
        # Callables applied in order by ``extend_prompt``.
        self.extenders = []

    def load_prompt_refiners(self, model_manager: "ModelManager", refiner_classes=()):
        """Instantiate each refiner class and append it to ``self.refiners``.

        Args:
            model_manager: Passed to every ``from_model_manager`` constructor.
            refiner_classes: Iterable of classes exposing
                ``from_model_manager(model_manager)``.
        """
        for refiner_class in refiner_classes:
            refiner = refiner_class.from_model_manager(model_manager)
            self.refiners.append(refiner)

    def load_prompt_extenders(self, model_manager: "ModelManager", extender_classes=()):
        """Instantiate each extender class and append it to ``self.extenders``.

        Args:
            model_manager: Passed to every ``from_model_manager`` constructor.
            extender_classes: Iterable of classes exposing
                ``from_model_manager(model_manager)``.
        """
        for extender_class in extender_classes:
            extender = extender_class.from_model_manager(model_manager)
            self.extenders.append(extender)

    @torch.no_grad()
    def process_prompt(self, prompt, positive=True):
        """Run ``prompt`` through every loaded refiner, in load order.

        Args:
            prompt: A single prompt, or a (possibly nested) list of prompts;
                lists are processed element by element, recursively.
            positive: Forwarded to each refiner as the ``positive`` keyword.

        Returns:
            The refined prompt, or a list mirroring the input list structure.
        """
        if isinstance(prompt, list):
            return [self.process_prompt(item, positive=positive) for item in prompt]
        for refiner in self.refiners:
            prompt = refiner(prompt, positive=positive)
        return prompt

    @torch.no_grad()
    def extend_prompt(self, prompt: str, positive=True):
        """Wrap ``prompt`` in a dict and pass it through every loaded extender.

        Args:
            prompt: Prompt stored under the ``"prompt"`` key of the initial dict.
            positive: Accepted but not used; extenders are called with the
                dict only.

        Returns:
            Whatever the last extender returns, or ``{"prompt": prompt}`` when
            no extenders are loaded.
        """
        extended_prompt = dict(prompt=prompt)
        for extender in self.extenders:
            extended_prompt = extender(extended_prompt)
        return extended_prompt