from PIL import Image
from io import BytesIO
import base64
import math
import ast
import torch
from transformers import StoppingCriteria
from oryx.constants import IMAGE_TOKEN_INDEX
import os

video_base = 0
video_ps = 64
highres_base = 0
highres_ps = 32
MAXRES = 1536
MINRES = 0
VIDEO_MAXRES = 480
VIDEO_MINRES = 288
LOWRES_RESIZE = (384, 32)
PAD2STRIDE = False


def pad_image(image, target_resolution, value=0):
    """
    Pad an image to a target resolution, centering it on a solid background.

    Args:
        image (PIL.Image.Image): The input image.
        target_resolution (tuple): The target (width, height) of the padded image.
        value (int): Grayscale fill value used for the padding.

    Returns:
        PIL.Image.Image: The padded image.
    """
    original_width, original_height = image.size
    target_width, target_height = target_resolution
    # Create a new image with the target size and paste the original image onto its center
    new_image = Image.new('RGB', (target_width, target_height), (value, value, value))
    paste_x = (target_width - original_width) // 2
    paste_y = (target_height - original_height) // 2
    new_image.paste(image, (paste_x, paste_y))
    return new_image


def resize_images(image, patch_size=14, base_size=896):
    # Note: PIL's Image.size is (width, height); the h/w names below follow the
    # original code and stay internally consistent with the resize calls.
    h, w = image.size
    if base_size == 0:
        if h * w > MAXRES * MAXRES:
            # print(f'{h}x{w} larger than max size {MAXRES}, resize to {MAXRES}')
            scale = MAXRES * MAXRES / (h * w)
            scale = math.sqrt(scale)
        elif h * w < MINRES * MINRES:
            # print(f'{h}x{w} smaller than min size {MINRES}, resize to {MINRES}')
            scale = MINRES * MINRES / (h * w)
            scale = math.sqrt(scale)
        else:
            scale = None
    else:
        scale = base_size * base_size / (h * w)
        scale = math.sqrt(scale)

    if scale is not None:
        # Rescale and round each side down to a multiple of the patch size
        new_h = int(h * scale / patch_size) * patch_size
        new_w = int(w * scale / patch_size) * patch_size
        image = image.resize((new_h, new_w))
    elif PAD2STRIDE:
        # Pad each side up to the next multiple of the patch size
        if h % patch_size == 0:
            new_h = h
        else:
            new_h = (h // patch_size + 1) * patch_size
        if w % patch_size == 0:
            new_w = w
        else:
            new_w = (w // patch_size + 1) * patch_size
        image = pad_image(image, (new_h, new_w), value=127)
    else:
        scale = 1.0
        new_h = int(h * scale / patch_size) * patch_size
        new_w = int(w * scale / patch_size) * patch_size
        image = image.resize((new_h, new_w))

    return image


def resize_video(image, patch_size=14, base_size=896):
    h, w = image.size
    if base_size == 0:
        if h * w > VIDEO_MAXRES * VIDEO_MAXRES:
            # print(f'{h}x{w} larger than max size {VIDEO_MAXRES}, resize to {VIDEO_MAXRES}')
            scale = VIDEO_MAXRES * VIDEO_MAXRES / (h * w)
            scale = math.sqrt(scale)
        elif h * w < VIDEO_MINRES * VIDEO_MINRES:
            # print(f'{h}x{w} smaller than min size {VIDEO_MINRES}, resize to {VIDEO_MINRES}')
            scale = VIDEO_MINRES * VIDEO_MINRES / (h * w)
            scale = math.sqrt(scale)
        else:
            scale = None
    else:
        scale = base_size * base_size / (h * w)
        scale = math.sqrt(scale)

    if scale is not None:
        new_h = int(h * scale / patch_size) * patch_size
        new_w = int(w * scale / patch_size) * patch_size
        image = image.resize((new_h, new_w))
    elif PAD2STRIDE:
        if h % patch_size == 0:
            new_h = h
        else:
            new_h = (h // patch_size + 1) * patch_size
        if w % patch_size == 0:
            new_w = w
        else:
            new_w = (w // patch_size + 1) * patch_size
        image = pad_image(image, (new_h, new_w), value=127)
    else:
        scale = 1.0
        new_h = int(h * scale / patch_size) * patch_size
        new_w = int(w * scale / patch_size) * patch_size
        image = image.resize((new_h, new_w))

    return image


def process_anyres_video_genli(image, processor):
    image = resize_video(image, patch_size=video_ps, base_size=video_base)
    image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
    return image.unsqueeze(0)


def process_anyres_video_genli_long(image, processor):
    # Same as process_anyres_video_genli, but with a doubled patch stride
    image = resize_video(image, patch_size=video_ps * 2, base_size=video_base)
    image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
    return image.unsqueeze(0)


def load_image_from_base64(image):
    return Image.open(BytesIO(base64.b64decode(image)))


def process_anyres_highres_image_genli(image, processor):
    h, w = image.size
    # Upscale tiny images so that both sides are large enough to produce at least one patch
    if h < 32 and w < 32:
        min_size = min(h, w)
        ratio = 64 / min_size
        image = image.resize((int(h * ratio), int(w * ratio)))
    elif h < 32:
        ratio = 64 / h
        image = image.resize((int(h * ratio), int(w * ratio)))
    elif w < 32:
        ratio = 64 / w
        image = image.resize((int(h * ratio), int(w * ratio)))

    image = resize_images(image, patch_size=highres_ps, base_size=highres_base)
    image_original_resize = resize_images(image, patch_size=LOWRES_RESIZE[1], base_size=LOWRES_RESIZE[0])

    # image_patches = [image_original_resize] + [image_original_resize]
    # image_patches = [processor.preprocess(image_patch, return_tensors='pt')['pixel_values'][0]
    #                  for image_patch in image_patches]
    image_patches = processor.preprocess(image_original_resize, return_tensors='pt')['pixel_values'][0]
    image_padded = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
    # return torch.stack(image_patches, dim=0), image_padded.unsqueeze(0)
    return image_patches.unsqueeze(0), image_padded.unsqueeze(0)


def read_image_patch(patch_info):
    if 'img_path' in patch_info.keys():
        image = Image.open(patch_info['img_path']).convert('RGB')
    else:
        # Tolerate the misspelled 'image_encoing' key found in some metadata files
        if 'image_encoing' in patch_info.keys():
            patch_info['image_encoding'] = patch_info['image_encoing']
        image_file_name = patch_info['patch']
        start_bytes = int(patch_info['start_num'])
        file_size = int(patch_info['size'])
        with open(image_file_name, 'rb') as f:
            f.seek(start_bytes)
            if 'image_encoding' in patch_info.keys() and patch_info['image_encoding'] == 'base64':
                image = Image.open(BytesIO(base64.b64decode(f.read(file_size).decode()))).convert("RGB")
            else:
                image = Image.open(BytesIO(f.read(file_size))).convert("RGB")
    return image


def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
    # Split the prompt on the '<image>' placeholder and splice the image token index
    # between the tokenized text chunks.
    prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<image>')]

    def insert_separator(X, sep):
        return [ele for sublist in zip(X, [sep] * len(X)) for ele in sublist][:-1]

    input_ids = []
    offset = 0
    if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
        # Keep a single BOS token at the start and skip it in every chunk below
        offset = 1
        input_ids.append(prompt_chunks[0][0])

    for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
        input_ids.extend(x[offset:])

    if return_tensors is not None:
        if return_tensors == 'pt':
            return torch.tensor(input_ids, dtype=torch.long)
        raise ValueError(f'Unsupported tensor type: {return_tensors}')
    return input_ids


def get_model_name_from_path(model_path):
    model_path = model_path.strip("/")
    model_paths = model_path.split("/")
    if model_paths[-1].startswith('checkpoint-'):
        return model_paths[-2] + "_" + model_paths[-1]
    else:
        return model_paths[-1]


class KeywordsStoppingCriteria(StoppingCriteria):
    def __init__(self, keywords, tokenizer, input_ids):
        self.keywords = keywords
        self.keyword_ids = []
        for keyword in keywords:
            cur_keyword_ids = tokenizer(keyword).input_ids
            if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
                # Drop the leading BOS token so only the keyword tokens are matched
                cur_keyword_ids = cur_keyword_ids[1:]
            self.keyword_ids.append(torch.tensor(cur_keyword_ids))
        self.tokenizer = tokenizer
        self.start_len = input_ids.shape[1]
1, "Only support batch size 1 (yet)" # TODO offset = min(output_ids.shape[1] - self.start_len, 3) self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids] for keyword_id in self.keyword_ids: if output_ids[0, -keyword_id.shape[0]:] == keyword_id: return True outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0] for keyword in self.keywords: if keyword in outputs: return True return False