File size: 6,092 Bytes
f64faed f7c35ba f64faed f7c35ba f64faed 6fadc9b f7c35ba f64faed f7c35ba f64faed 13a7b44 f64faed 13a7b44 f64faed f7c35ba f64faed 125e69e f7c35ba 125e69e d3d22eb 125e69e f7c35ba 125e69e 8fe453b a378671 125e69e a378671 125e69e a378671 125e69e a378671 125e69e a378671 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 |
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM
from PIL import Image
import torchvision.transforms as transforms
import types
import mobileclip
# Pick the compute device and the matching floating-point precision once:
# half precision on a CUDA GPU, full precision when falling back to the CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda'
    DTYPE = torch.float16
else:
    DEVICE = 'cpu'
    DTYPE = torch.float32
def split_chessboard(x, num_split):
    """Cut a 4-D tensor into a ``num_split`` x ``num_split`` grid of tiles.

    The tiles are taken over the last two dimensions and concatenated along
    dimension 0 in row-major grid order, so the result has ``num_split**2``
    times as many entries along dimension 0 as ``x``. Each tile is
    ``H // num_split`` by ``W // num_split``; rows/columns left over when the
    size is not a multiple of ``num_split`` are not part of any tile.

    Args:
        x: Tensor with exactly four dimensions.
        num_split: Number of tiles along each of the two tiled dimensions.

    Returns:
        The tiles concatenated along dimension 0.
    """
    _, _, height, width = x.shape
    tile_h = height // num_split
    tile_w = width // num_split

    tiles = []
    for row in range(num_split):
        row_span = slice(row * tile_h, (row + 1) * tile_h)
        for col in range(num_split):
            col_span = slice(col * tile_w, (col + 1) * tile_w)
            tiles.append(x[:, :, row_span, col_span])
    return torch.cat(tiles, dim=0)
def merge_chessboard(x, num_split):
    """Reassemble tiles stacked along dimension 0 into one larger map.

    Dimension 0 of ``x`` is read as ``num_split**2`` consecutive groups in
    row-major grid order. Groups belonging to the same grid row are joined
    along the last dimension, and the resulting rows are then joined along
    the second-to-last dimension.

    Args:
        x: Tensor with exactly four dimensions whose first dimension holds
            the stacked tiles.
        num_split: Number of tiles along each side of the grid.

    Returns:
        The merged tensor; its first dimension is
        ``x.shape[0] // num_split**2``.
    """
    total, _, _, _ = x.shape
    group = total // (num_split ** 2)

    grid_rows = []
    for row in range(num_split):
        row_tiles = []
        for col in range(num_split):
            start = (row * num_split + col) * group
            row_tiles.append(x[start:start + group])
        # Tiles of one grid row sit side by side (last dimension).
        grid_rows.append(torch.cat(row_tiles, dim=-1))
    # Grid rows are stacked on top of each other (second-to-last dimension).
    return torch.cat(grid_rows, dim=-2)
class FeatureIRLayer(nn.Module):
    """Linear projection followed by a tanh-approximated GELU.

    Args:
        in_dim: Size of the last dimension of the input.
        hidden_dim: Size of the last dimension of the output. Previously this
            argument was ignored and the width was hard-coded to 4096; it is
            now honoured, so passing 4096 keeps the old behaviour (and the
            same parameter names/shapes in the state dict).
        out_dim: Accepted for interface compatibility only; this layer does
            not use it.
    """

    def __init__(self, in_dim: int, hidden_dim: int, out_dim: int) -> None:
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.GELU(approximate='tanh'),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the projection and activation to ``x`` and return the result."""
        return self.mlp(x)
class MobileVision(nn.Module):
    """Vision tower: a MobileCLIP image backbone plus two projection layers.

    The backbone is run on a downscaled copy of the image and on a 2x2 grid
    of tiles of the original image; both feature maps are concatenated along
    the channel dimension and projected to a sequence of embeddings.
    """

    def __init__(self):
        super().__init__()
        clip_model, _, _ = mobileclip.create_model_and_transforms(
            'mobileclip_s2', pretrained='mobileclip_s2.pt'
        )
        # Register the full model first, then replace it with just the inner
        # image-encoder network (eval mode, on the configured device/dtype).
        self.vision = clip_model
        self.vision = self.vision.image_encoder.model.eval().to(DEVICE).to(DTYPE)

        def _feature_map_forward(self, x: torch.Tensor) -> torch.Tensor:
            # Replacement forward for the backbone: embeddings -> tokens ->
            # conv_exp, returning whatever conv_exp produces.
            embedded = self.forward_embeddings(x)
            tokens = self.forward_tokens(embedded)
            return self.conv_exp(tokens)

        self.vision.forward = types.MethodType(_feature_map_forward, self.vision)

        # Input width 1280*2 matches the two concatenated backbone outputs
        # (assumes each has 1280 channels - TODO confirm against the backbone).
        self.projection = FeatureIRLayer(1280 * 2, 4096, 1536).to(DEVICE).to(DTYPE)
        self.projection2 = nn.Linear(4096, 1536).to(DEVICE).to(DTYPE)

    def forward(self, x):
        """Encode a batch of images into a sequence of 1536-wide embeddings.

        Args:
            x: Image batch as a 4-D tensor (assumed to already be on DEVICE
                with dtype DTYPE - TODO confirm with callers).

        Returns:
            Tensor with the spatial positions flattened into dimension 1 and
            the output of ``projection2`` in the last dimension.
        """
        # Global view: whole image resized to 256x256.
        downscaled = F.interpolate(x, size=(256, 256), mode='bilinear', align_corners=False)
        global_feats = self.vision(downscaled)

        # Local view: encode a 2x2 grid of tiles, stitch the feature maps back
        # together, then pool them to 8x8 before concatenating.
        tile_feats = self.vision(split_chessboard(x, 2))
        stitched = merge_chessboard(tile_feats, 2)
        local_feats = F.interpolate(stitched, size=(8, 8), mode='area')

        fused = torch.cat([global_feats, local_feats], dim=1)
        # (batch, channels, h, w) -> (batch, h*w, channels)
        sequence = fused.reshape(fused.size(0), fused.size(1), -1).permute(0, 2, 1)
        return self.projection2(self.projection(sequence))
class MoondreamModel(nn.Module):
    """Image-captioning model: a MobileVision encoder feeding a causal LM.

    Image embeddings are spliced into the text embedding sequence right after
    the first text token, and the combined sequence is run through the
    language model.
    """

    def __init__(self):
        super().__init__()
        self.vision_encoder = MobileVision()
        self.text_model = AutoModelForCausalLM.from_pretrained(
            "h2oai/h2o-danube3-500m-chat",
            trust_remote_code=True,
            torch_dtype=DTYPE,
            device_map={"": DEVICE}
        )
        # NOTE(review): depending on the PyTorch version's `weights_only`
        # default, torch.load may unpickle arbitrary objects - only load
        # checkpoints from a trusted source.
        self.load_state_dict(torch.load('moondream_model_state_dict.pt', map_location=DEVICE))

    def forward(self, images, tokens):
        """Run the language model over text tokens with image embeddings inserted.

        Args:
            images: Image batch passed to the vision encoder.
            tokens: Token-id tensor; dimension 1 is the sequence dimension.

        Returns:
            The language model's output object for the sequence
            ``[first token, image embeddings, remaining tokens]``.
        """
        img_embs = self.vision_encoder(images)
        tok_embs = self.text_model.get_input_embeddings()(tokens)
        inputs_embeds = torch.cat((tok_embs[:, 0:1, :], img_embs, tok_embs[:, 1:, :]), dim=1)
        outputs = self.text_model(inputs_embeds=inputs_embeds)
        return outputs

    @staticmethod
    def load_model():
        """Build the model, move it to DEVICE, and use half precision on CUDA."""
        model = MoondreamModel().to(DEVICE)
        # Only apply half() if using a GPU
        if torch.cuda.is_available():
            model = model.half()
        return model

    @staticmethod
    def load_tokenizer():
        """Load the tokenizer that matches the language model checkpoint."""
        tokenizer = AutoTokenizer.from_pretrained("h2oai/h2o-danube3-500m-chat", trust_remote_code=True)
        return tokenizer

    @staticmethod
    def preprocess_image(image, img_size=512):
        """Resize an image and convert it to a DTYPE tensor on DEVICE.

        Args:
            image: Input image (a PIL image in the visible usage).
            img_size: Side length of the square the image is resized to.

        Returns:
            Image tensor on DEVICE with dtype DTYPE (no batch dimension).
        """
        # Grayscale, palette and RGBA images would otherwise produce a tensor
        # whose channel count is not 3; normalise the mode up front.
        if isinstance(image, Image.Image) and image.mode != "RGB":
            image = image.convert("RGB")
        transform = transforms.Compose([
            transforms.Resize((img_size, img_size)),
            transforms.ToTensor(),
            transforms.Lambda(lambda x: x.to(DTYPE)),
        ])
        image = transform(image).to(DEVICE)
        return image

    @staticmethod
    def generate_caption(model, image, tokenizer, max_length=192):
        """Greedily decode a caption for a single preprocessed image.

        Args:
            model: Object exposing ``vision_encoder`` and ``text_model``.
            image: Preprocessed image tensor without a batch dimension.
            tokenizer: Tokenizer matching ``model.text_model``.
            max_length: Maximum number of tokens to generate.

        Returns:
            The decoded text (prompt plus generated tokens) with special
            tokens skipped.
        """
        model.eval()  # Set model to evaluation mode

        # Stop on the separator token as before, and also on end-of-sequence;
        # either id may be undefined (None) for a given tokenizer.
        stop_ids = {
            token_id
            for token_id in (tokenizer.sep_token_id, tokenizer.eos_token_id)
            if token_id is not None
        }

        with torch.no_grad():  # Disable gradients for faster inference
            image = image.unsqueeze(0).to(DEVICE)
            img_embs = model.vision_encoder(image)
            embed_tokens = model.text_model.get_input_embeddings()

            prompt_ids = tokenizer(
                "\n\nDescriptions of the image:",
                add_special_tokens=False
            ).input_ids
            generated = [tokenizer.bos_token_id, *prompt_ids]

            # First step: feed the whole prefix
            # [BOS, image embeddings, prompt tokens] to fill the KV cache.
            input_ids = torch.tensor(generated, dtype=torch.long, device=DEVICE).unsqueeze(0)
            tok_embs = embed_tokens(input_ids)
            step_embeds = torch.cat((tok_embs[:, 0:1, :], img_embs, tok_embs[:, 1:, :]), dim=1)

            past_key_values = None
            for _ in range(max_length):
                outputs = model.text_model(
                    inputs_embeds=step_embeds,
                    past_key_values=past_key_values,
                    use_cache=True
                )
                past_key_values = outputs.past_key_values  # Update KV cache
                next_token = torch.argmax(outputs.logits[:, -1, :], dim=-1).item()
                if next_token in stop_ids:
                    break
                generated.append(next_token)
                # The cache already holds every earlier position, so later
                # steps feed only the new token. Re-feeding the full sequence
                # together with the cache would duplicate the context.
                step_embeds = embed_tokens(
                    torch.tensor([[next_token]], dtype=torch.long, device=DEVICE)
                )

        return tokenizer.decode(generated, skip_special_tokens=True)
# Example usage:
# Load the model and tokenizer
model = MoondreamModel.load_model()
tokenizer = MoondreamModel.load_tokenizer()
# Load and preprocess an image. The context manager closes the underlying
# file once the pixel data has been converted to a tensor.
with Image.open("path_to_image.jpg") as image:
    preprocessed_image = MoondreamModel.preprocess_image(image)
# Generate a caption for the image
caption = MoondreamModel.generate_caption(model, preprocessed_image, tokenizer)
print("Generated Caption:", caption)
|