| | from typing import List |
| | from queue import Queue |
| |
|
| | import torch |
| | from PIL import Image |
| | from copy import deepcopy |
| | import requests, os |
| |
|
# Default id that tokenizer_image_token emits in place of each '<image>'.
IMAGE_TOKEN_INDEX = -200
# Special-token strings removed from conversation text by input_moderation.
blacklist = ['<image>', '<s>', '</s>']
# Limit on the number of images per call, checked in build_allava_input.
max_num_images = 3
| |
|
def input_moderation(texts: list[list[str]], blocked_tokens=None):
    """Strip blocked special-token strings from every conversation turn, in place.

    Args:
        texts: list of ``[user_text, assistant_text]`` pairs. The assistant
            slot may be ``None`` and is then left untouched.
        blocked_tokens: strings to remove. Defaults to the module-level
            ``blacklist``.

    Returns:
        The same ``texts`` object, with its pairs edited in place.

    Removal repeats until the text stops changing. A single ``str.replace``
    pass is not enough: nested input such as ``'<ima<image>ge>'`` or
    ``'<</s>s>'`` would reassemble a blocked token after one pass.
    """
    tokens = blacklist if blocked_tokens is None else blocked_tokens

    def _scrub(text):
        # Each loop either leaves ``text`` unchanged (exit) or shortens it,
        # so this always terminates.
        previous = None
        while text != previous:
            previous = text
            for token in tokens:
                text = text.replace(token, '')
        return text

    for text_pair in texts:
        text_pair[0] = _scrub(text_pair[0])
        if text_pair[1] is not None:
            text_pair[1] = _scrub(text_pair[1])

    return texts
| |
|
def insert_image_placeholder(t, num_images, placeholder='<image>', sep='\n'):
    """Prefix ``t`` with ``num_images`` copies of ``placeholder`` followed by ``sep``.

    A non-positive ``num_images`` returns ``t`` unchanged.
    """
    if num_images <= 0:
        return t
    return f"{placeholder}{sep}" * num_images + t
| |
|
def get_conv(texts):
    """Flatten ``[user, assistant]`` pairs into a list of role-tagged turn dicts.

    Every pair yields two dicts, ``{'from': 'human', ...}`` followed by
    ``{'from': 'gpt', ...}``. The assistant value is copied as-is, even when
    it is ``None``.
    """
    return [
        turn
        for pair in texts
        for turn in (
            {'from': 'human', 'value': pair[0]},
            {'from': 'gpt', 'value': pair[1]},
        )
    ]
| |
|
| | |
def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
    """Tokenize ``prompt`` and replace each literal ``'<image>'`` with ``image_token_index``.

    The prompt is split on ``'<image>'``. Each text chunk is tokenized
    separately with ``add_special_tokens=False``, and the resulting id lists
    are joined with exactly one ``image_token_index`` between neighbours.

    Args:
        prompt: text that may contain ``'<image>'`` placeholders.
        tokenizer: callable whose result exposes ``.input_ids``.
        image_token_index: id inserted for every placeholder.
        return_tensors: ``None`` for a Python list, or ``'pt'`` for a 1-D
            ``torch.long`` tensor.

    Raises:
        ValueError: if ``return_tensors`` is neither ``None`` nor ``'pt'``.
    """
    prompt_chunks = [tokenizer(chunk, add_special_tokens=False).input_ids for chunk in prompt.split('<image>')]

    # Chunks are tokenized WITHOUT special tokens, so no chunk gains an
    # automatic BOS. Every chunk's ids are therefore real content and are kept
    # whole. (The LLaVA original stripped one leading token from each
    # subsequent chunk, which is only right when every chunk starts with BOS.)
    input_ids = list(prompt_chunks[0])
    for chunk in prompt_chunks[1:]:
        input_ids.append(image_token_index)
        input_ids.extend(chunk)

    if return_tensors is not None:
        if return_tensors == 'pt':
            return torch.tensor(input_ids, dtype=torch.long)
        raise ValueError(f'Unsupported tensor type: {return_tensors}')
    return input_ids
| | |
def preprocess(tokenizer, data: list, return_tensors='pt'):
    """Tokenize a conversation given as a list of turn dicts.

    ``data`` looks like::

        [
            {'from': 'human', 'value': ...},
            {'from': 'gpt', 'value': ...},
        ]

    The work is delegated to ``preprocess_allava``.

    Raises:
        ValueError: if ``data`` is not a list.
    """
    if isinstance(data, list):
        return preprocess_allava(tokenizer, data, return_tensors=return_tensors)
    raise ValueError('must be a list')
| |
|
| | |
| |
|
def preprocess_vicuna_v1(self, convs: list, return_tensors) -> list:
    """Tokenize alternating turns in the Vicuna v1 ``USER:`` / ``ASSISTANT:`` format.

    ``self`` must provide ``tokenizer_image_token``, ``tokenizer`` and
    ``maxlen``. Turns at even positions are user messages; the stripped text,
    wrapped as ``"USER: ... "``, goes through ``self.tokenizer_image_token``.
    Turns at odd positions are replies, tokenized as
    ``"ASSISTANT: <text></s>"``. A ``None`` reply becomes the bare
    ``"ASSISTANT:"`` generation prompt. Returns ``None`` for empty ``convs``.
    """
    def _assistant_ids(text):
        return self.tokenizer(text, add_special_tokens= False, max_length=self.maxlen, truncation=True, return_tensors='pt').input_ids[0]

    input_ids = None
    for turn_idx, turn in enumerate(convs):
        if turn_idx % 2 == 0:
            user_text = f"USER: {turn['value'].strip()} "
            piece = self.tokenizer_image_token(prompt=user_text, return_tensors=return_tensors)
            input_ids = piece if input_ids is None else torch.cat([input_ids, piece])
            continue

        reply = turn['value']
        prompt = "ASSISTANT:" if reply is None else f"ASSISTANT: {reply}</s>"
        input_ids = torch.cat([input_ids, _assistant_ids(prompt)])

    return input_ids
| |
|
def preprocess_allava(tokenizer, convs: list, return_tensors) -> list:
    """Tokenize alternating user/assistant turns in the ``[INST] ... [/INST]`` format.

    Turns at even positions are user messages. Each is stripped, wrapped as
    ``"[INST] ... [/INST] "`` and passed through ``tokenizer_image_token``.
    Turns at odd positions are assistant replies. Each reply's text plus
    ``tokenizer.eos_token`` is appended, and a ``None`` reply (the turn still
    to be generated) contributes nothing.

    Returns ``None`` for empty ``convs``. Otherwise returns the concatenated
    ids, which are a torch tensor when ``return_tensors='pt'`` despite the
    ``list`` annotation.
    """
    input_ids = None
    for position, turn in enumerate(convs):
        text = turn['value']

        if position % 2 == 1:
            if text is None:
                continue
            reply_ids = tokenizer(f"{text}{tokenizer.eos_token}", add_special_tokens= False, truncation=True, return_tensors='pt').input_ids[0]
            input_ids = torch.cat([input_ids, reply_ids])
            continue

        prompt = f"[INST] {text.strip()} [/INST] "
        prompt_ids = tokenizer_image_token(prompt=prompt, tokenizer=tokenizer, return_tensors=return_tensors)
        input_ids = prompt_ids if input_ids is None else torch.cat([input_ids, prompt_ids])

    return input_ids
| |
|
| |
|
| | |
def _expand2square(pil_img, background_color):
    """Return ``pil_img`` padded to a square canvas, with the original centred.

    The canvas uses ``pil_img``'s own mode, so ``background_color`` must be a
    valid fill for that mode (a 3-tuple for RGB).
    """
    width, height = pil_img.size
    if width == height:
        return pil_img
    side = max(width, height)
    canvas = Image.new(pil_img.mode, (side, side), background_color)
    canvas.paste(pil_img, ((side - width) // 2, (side - height) // 2))
    return canvas


def get_image_tensors(processor, images, device):
    """Convert image specs into pixel tensors on ``device``.

    Args:
        processor: image processor exposing ``crop_size`` (a dict with
            ``'height'`` and ``'width'``), ``image_mean`` and ``preprocess``.
        images: iterable whose entries are ``None``, a file path, or a
            ``PIL.Image.Image``.
        device: target device for every returned tensor.

    Returns:
        A list with one tensor per entry. A ``None`` entry gives a zero tensor
        of shape ``(3, height, width)``. Any other entry is converted to RGB,
        padded to a square, and replaced by the processor's ``pixel_values[0]``.

    Raises:
        TypeError: for an entry of any other type.
    """
    crop_size = processor.crop_size
    list_image_tensors = []
    for fp in images:
        if fp is None:
            list_image_tensors.append(torch.zeros(3, crop_size['height'], crop_size['width']).to(device))
            continue
        if isinstance(fp, str):
            # Close the file handle once the pixels are decoded.
            with Image.open(fp) as opened:
                image = opened.convert('RGB')
        elif isinstance(fp, Image.Image):
            image = fp
        else:
            raise TypeError(f'Unsupported type {type(fp)}')

        # Normalise every mode (L, LA, P, RGBA, CMYK, ...) to RGB. The pad
        # colour below is a 3-tuple and must match the canvas mode.
        if image.mode != 'RGB':
            image = image.convert('RGB')

        # Pad with the processor's per-channel mean scaled to 0-255
        # (presumably so padding normalises to ~0; confirm against processor).
        background = tuple(int(x * 255) for x in processor.image_mean)
        image = _expand2square(image, background)
        pixel_values = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
        list_image_tensors.append(pixel_values.to(device))

    return list_image_tensors
| |
|
| |
|
| |
|
| |
|
def _load_image(img):
    """Resolve one image spec to an RGB PIL image, or ``None`` if it cannot be loaded.

    Accepts an open ``PIL.Image.Image``, a local file path, or a URL. Loading
    is best-effort: any failure triggers a warning and returns ``None`` so the
    caller can skip that entry.
    """
    import warnings

    if img is None:
        return None
    if isinstance(img, Image.Image):
        return img if img.mode == 'RGB' else img.convert('RGB')
    try:
        if os.path.exists(img):
            with Image.open(img) as opened:
                return opened.convert('RGB')
        # convert() forces a full decode, so reading finishes before the
        # response is closed on leaving the ``with`` block.
        with requests.get(img, stream=True, timeout=30) as response:
            response.raise_for_status()
            return Image.open(response.raw).convert('RGB')
    except Exception as err:  # best-effort: unloadable images are skipped
        warnings.warn(f'Skipping image {img!r}: {err}')
        return None


def build_allava_input(tokenizer, processor, texts, images, history=None, return_history=False, device='cuda'):
    """Build model inputs (token ids + image tensors) for an ALLaVA conversation.

    Args:
        tokenizer: text tokenizer passed to ``preprocess``.
        processor: image processor passed to ``get_image_tensors``.
        texts: a single user message (``str``), or a non-empty list of
            ``[user_text, assistant_text_or_None]`` pairs.
        images: ``None``, a single path/URL/PIL image, or a list of them.
            Entries that cannot be loaded are skipped with a warning.
        history: optional earlier ``[user, assistant]`` pairs, prepended to
            ``texts``.
        return_history: when true, also return the moderated conversation
            (before image placeholders are inserted).
        device: device for both returned tensors.

    Returns:
        ``(input_ids, image_tensors, history_or_None)``. ``input_ids`` has a
        leading batch dimension of 1. ``image_tensors`` stacks one tensor per
        loaded image, or a single zero tensor when no image loaded.

    Neither ``texts`` nor ``history`` is mutated.
    """
    if isinstance(texts, str):
        texts = [[texts, None]]
    else:
        assert isinstance(texts, list) and texts and isinstance(texts[0], list), 'texts must be a list of list'
        # Moderation and placeholder insertion edit pairs in place; work on a copy.
        texts = deepcopy(texts)

    if history is not None:
        texts = deepcopy(history) + texts

    texts = input_moderation(texts)

    if isinstance(images, (str, Image.Image)):
        images = [images]
    if images is None:
        images = []

    loaded = [img for img in map(_load_image, images) if img is not None]

    assert len(loaded) <= max_num_images, f'Currently at most {max_num_images} images are supported'

    # Snapshot before the '<image>' placeholders are added to the first turn.
    history = deepcopy(texts)

    texts[0][0] = insert_image_placeholder(texts[0][0], len(loaded))

    conv = get_conv(texts)

    input_ids = preprocess(tokenizer, conv, return_tensors='pt').unsqueeze(0).to(device)

    # With no usable image, a single None makes get_image_tensors emit a zero tensor.
    list_image_tensors = get_image_tensors(processor, loaded or [None], device)
    image_tensors = torch.stack(list_image_tensors)

    if return_history:
        return input_ids, image_tensors, history

    return input_ids, image_tensors, None
| |
|
| |
|
| |
|
class TextIterStreamer:
    """Collect generated token ids and expose the decoded text as an iterator.

    ``put`` receives id tensors. When ``skip_prompt`` is true, the first call
    (the prompt) is ignored. Every later call appends the ids and queues the
    decoding of ALL ids received so far, so each yielded string is cumulative.
    ``end`` queues ``None``, which stops iteration.
    """

    def __init__(self, tokenizer, skip_prompt=False, skip_special_tokens=False):
        self.tokenizer = tokenizer
        self.skip_prompt = skip_prompt
        self.skip_special_tokens = skip_special_tokens
        self.tokens = []
        self.text_queue = Queue()
        self.next_tokens_are_prompt = True

    def put(self, value):
        if self.skip_prompt and self.next_tokens_are_prompt:
            # The first call carries the prompt ids; drop it once.
            self.next_tokens_are_prompt = False
            return
        # Batched input: only the first row is streamed.
        ids = value[0] if len(value.shape) > 1 else value
        self.tokens.extend(ids.tolist())
        decoded = self.tokenizer.decode(self.tokens, skip_special_tokens=self.skip_special_tokens)
        self.text_queue.put(decoded)

    def end(self):
        self.text_queue.put(None)

    def __iter__(self):
        return self

    def __next__(self):
        text = self.text_queue.get()
        if text is None:
            raise StopIteration()
        return text
| |
|