import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM
from PIL import Image
import torchvision.transforms as transforms
import types
import os
import sys
import mobileclip
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
DTYPE = torch.float16
def split_chessboard(x, num_split):
    # Split each image into a num_split x num_split grid of equally sized
    # patches, stacked along the batch dimension.
    B, C, H, W = x.shape
    h, w = H // num_split, W // num_split
    x_split = torch.cat([x[:, :, i*h:(i+1)*h, j*w:(j+1)*w] for i in range(num_split) for j in range(num_split)], dim=0)
    return x_split

def merge_chessboard(x, num_split):
    # Inverse of split_chessboard: reassemble the stacked patches back into
    # full images along the spatial dimensions.
    B, C, H, W = x.shape
    b = B // (num_split**2)
    x_merge = torch.cat([torch.cat([x[(i*num_split + j)*b:(i*num_split + j + 1)*b] for j in range(num_split)], dim=-1)
                         for i in range(num_split)], dim=-2)
    return x_merge

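# Quick sanity check for the split/merge pair (a small illustrative sketch,
# not part of the original pipeline; call it manually if desired): merging a
# split tensor must reproduce the original exactly.
def _check_chessboard_roundtrip():
    x = torch.arange(2 * 3 * 8 * 8, dtype=torch.float32).reshape(2, 3, 8, 8)
    assert torch.equal(merge_chessboard(split_chessboard(x, 2), 2), x)
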
class FeatureIRLayer(nn.Module):
    def __init__(self, in_dim: int, hidden_dim: int, out_dim: int) -> None:
        super().__init__()
        # First half of the vision-to-language projection; the final
        # hidden_dim -> out_dim linear is applied separately in
        # MobileVision.projection2.
        self.mlp = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.GELU(approximate='tanh')
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.mlp(x)

class MobileVision(nn.Module):
    def __init__(self):
        super(MobileVision, self).__init__()
        self.vision, _, _ = mobileclip.create_model_and_transforms('mobileclip_s2', pretrained='mobileclip_s2.pt')
        self.vision = self.vision.image_encoder.model.eval().to(DEVICE).half()

        # Patch the encoder's forward so it returns the spatial feature map
        # instead of a pooled embedding, keeping one token per grid location.
        def new_forward(self, x: torch.Tensor) -> torch.Tensor:
            x = self.forward_embeddings(x)
            x = self.forward_tokens(x)
            return self.conv_exp(x)

        self.vision.forward = types.MethodType(new_forward, self.vision)
        self.projection = FeatureIRLayer(1280 * 2, 4096, 1536).to(DEVICE).half()
        self.projection2 = nn.Linear(4096, 1536).to(DEVICE).half()

    def forward(self, x):
        # Frozen encoder: extract global features (512x512 input downsampled
        # to 256x256) and local features (a 2x2 grid of 256x256 crops)
        # without gradients.
        with torch.no_grad():
            resized_img = F.interpolate(x, size=(256, 256), mode='bilinear', align_corners=False)
            out1 = self.vision(resized_img)
            x = split_chessboard(x, 2)
            x = self.vision(x)
            x = merge_chessboard(x, 2)
            x = F.interpolate(x, size=(8, 8), mode='area')
        # Concatenate global and local features channel-wise, flatten the
        # spatial grid into a token sequence, and project to the LM width.
        x = torch.cat([out1, x], dim=1)
        x = x.reshape(x.size(0), x.size(1), -1)
        x = x.permute(0, 2, 1)
        x = self.projection(x)
        x = self.projection2(x)
        return x

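# Minimal shape sketch for MobileVision (the helper name and the dummy input
# are illustrative; the (1, 64, 1536) expectation assumes the MobileCLIP-S2
# backbone yields an 8x8, 1280-channel map per 256x256 crop, as the code
# above implies):
def _vision_shape_demo(encoder):
    dummy = torch.zeros(1, 3, 512, 512, dtype=DTYPE, device=DEVICE)
    print(encoder(dummy).shape)  # expected: torch.Size([1, 64, 1536])
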
class MoondreamModel(nn.Module):
    def __init__(self):
        super(MoondreamModel, self).__init__()
        self.vision_encoder = MobileVision()
        self.text_model = AutoModelForCausalLM.from_pretrained(
            "h2oai/h2o-danube3-500m-chat",
            trust_remote_code=True,
            torch_dtype=DTYPE,
            device_map={"": DEVICE}
        )
        # map_location keeps the load working on CPU-only machines.
        self.load_state_dict(torch.load('moondream_model_state_dict.pt', map_location=DEVICE))

    def forward(self, images, tokens):
        img_embs = self.vision_encoder(images)
        tok_embs = self.text_model.get_input_embeddings()(tokens)
        # Splice the image tokens between the BOS embedding and the rest of
        # the text embeddings.
        inputs_embeds = torch.cat((tok_embs[:, 0:1, :], img_embs, tok_embs[:, 1:, :]), dim=1)
        outputs = self.text_model(inputs_embeds=inputs_embeds)
        return outputs

    @staticmethod
    def load_model():
        model = MoondreamModel().to(DEVICE).half()
        return model

    @staticmethod
    def load_tokenizer():
        tokenizer = AutoTokenizer.from_pretrained("h2oai/h2o-danube3-500m-chat", trust_remote_code=True)
        return tokenizer

    @staticmethod
    def preprocess_image(image_path, img_size=512):
        transform = transforms.Compose([
            transforms.Resize((img_size, img_size)),
            transforms.ToTensor(),
            transforms.Lambda(lambda x: x.to(torch.float16)),
        ])
        image = Image.open(image_path).convert('RGB')
        image = transform(image).to(DEVICE)
        return image

    @staticmethod
    def generate_caption(model, image, tokenizer, max_length=128):
        model.eval()
        with torch.no_grad():
            image = image.unsqueeze(0).to(DEVICE)
            img_embs = model.vision_encoder(image)
            generated = [tokenizer.bos_token_id]
            descriptive_prompt = tokenizer(
                "\n\nDescriptions of the image:",
                add_special_tokens=False
            ).input_ids
            generated.extend(descriptive_prompt)
            # Greedy decoding; the full sequence is re-encoded every step
            # (no KV cache), which is simple but O(n^2) in sequence length.
            for _ in range(max_length):
                input_ids = torch.tensor(generated, dtype=torch.long, device=DEVICE).unsqueeze(0)
                tok_embs = model.text_model.get_input_embeddings()(input_ids)
                inputs_embeds = torch.cat((tok_embs[:, 0:1, :], img_embs, tok_embs[:, 1:, :]), dim=1)
                outputs = model.text_model(inputs_embeds=inputs_embeds)
                next_token_logits = outputs.logits[:, -1, :]
                next_token = torch.argmax(next_token_logits, dim=-1).item()
                # Stop on end-of-sequence. sep_token_id is typically None for
                # Llama-style tokenizers, so eos_token_id is the right check.
                if next_token == tokenizer.eos_token_id:
                    break
                generated.append(next_token)
        return tokenizer.decode(generated, skip_special_tokens=True)

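# Example usage (a minimal sketch: assumes 'mobileclip_s2.pt' and
# 'moondream_model_state_dict.pt' are present, and takes a hypothetical
# image path from the command line):
if __name__ == '__main__':
    image_path = sys.argv[1] if len(sys.argv) > 1 else 'example.jpg'
    model = MoondreamModel.load_model()
    tokenizer = MoondreamModel.load_tokenizer()
    image = MoondreamModel.preprocess_image(image_path)
    print(MoondreamModel.generate_caption(model, image, tokenizer))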