# edge_vlm / model.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM
from PIL import Image
import torchvision.transforms as transforms
import types
import mobileclip
# Use the GPU with float16 if available, otherwise fall back to CPU with float32
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32
def split_chessboard(x, num_split):
    # Split each image in the batch into a num_split x num_split grid of
    # patches and stack the patches along the batch dimension.
    B, C, H, W = x.shape
    h, w = H // num_split, W // num_split
    x_split = torch.cat([x[:, :, i*h:(i+1)*h, j*w:(j+1)*w]
                         for i in range(num_split) for j in range(num_split)], dim=0)
    return x_split
def merge_chessboard(x, num_split):
    # Inverse of split_chessboard: reassemble the stacked patches back into
    # full images by tiling along the height and width dimensions.
    B, C, H, W = x.shape
    b = B // (num_split**2)
    x_merge = torch.cat([torch.cat([x[(i*num_split + j)*b:(i*num_split + j + 1)*b]
                                    for j in range(num_split)], dim=-1)
                         for i in range(num_split)], dim=-2)
    return x_merge
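# A minimal round-trip sketch (not in the original file) showing that the two
# helpers above are exact inverses on a dummy tensor;
# `_chessboard_roundtrip_check` is a hypothetical name added for illustration.
def _chessboard_roundtrip_check():
    x = torch.arange(2 * 3 * 8 * 8, dtype=torch.float32).reshape(2, 3, 8, 8)
    patches = split_chessboard(x, 2)         # (8, 3, 4, 4): 4 patches per image
    restored = merge_chessboard(patches, 2)  # back to (2, 3, 8, 8)
    assert torch.equal(x, restored)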
class FeatureIRLayer(nn.Module):
    def __init__(self, in_dim: int, hidden_dim: int, out_dim: int) -> None:
        super().__init__()
        # Note: hidden_dim and out_dim are accepted but unused; the layer
        # projects in_dim -> 4096, and a separate Linear (projection2 in
        # MobileVision) maps 4096 down to the final embedding size.
        self.mlp = nn.Sequential(
            nn.Linear(in_dim, 4096),
            nn.GELU(approximate='tanh')
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.mlp(x)
class MobileVision(nn.Module):
    def __init__(self):
        super(MobileVision, self).__init__()
        self.vision, _, _ = mobileclip.create_model_and_transforms(
            'mobileclip_s2', pretrained='mobileclip_s2.pt')
        self.vision = self.vision.image_encoder.model.eval().to(DEVICE).to(DTYPE)

        # Monkey-patch the encoder's forward so it returns the final
        # convolutional feature map instead of the pooled CLIP embedding.
        def new_forward(self, x: torch.Tensor) -> torch.Tensor:
            x = self.forward_embeddings(x)
            x = self.forward_tokens(x)
            return self.conv_exp(x)

        self.vision.forward = types.MethodType(new_forward, self.vision)
        self.projection = FeatureIRLayer(1280*2, 4096, 1536).to(DEVICE).to(DTYPE)
        self.projection2 = nn.Linear(4096, 1536).to(DEVICE).to(DTYPE)
    def forward(self, x):
        # Global pass: the full image resized to the encoder's 256x256 input.
        resized_img = F.interpolate(x, size=(256, 256), mode='bilinear', align_corners=False)
        out1 = self.vision(resized_img)
        # Local pass: encode each quadrant of the full-resolution image, then
        # stitch the feature maps back together and pool to the global grid size.
        x = split_chessboard(x, 2)
        x = self.vision(x)
        x = merge_chessboard(x, 2)
        x = F.interpolate(x, size=(8, 8), mode='area')
        # Concatenate global and local features along channels, flatten the
        # spatial grid into a token sequence, and project into the LLM
        # embedding space.
        x = torch.cat([out1, x], dim=1)
        x = x.reshape(x.size(0), x.size(1), -1)
        x = x.permute(0, 2, 1)
        x = self.projection(x)
        x = self.projection2(x)
        return x
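# Rough shape trace for MobileVision.forward, assuming the default 512x512
# preprocessing below and the 1280-channel feature map implied by the
# 1280*2 projection input:
#   input x:          (B, 3, 512, 512)
#   global pass out1: (B, 1280, 8, 8) from the 256x256 resized image
#   local pass:       (4B, 3, 256, 256) -> (4B, 1280, 8, 8) -> merged to
#                     (B, 1280, 16, 16) -> area-pooled to (B, 1280, 8, 8)
#   concat + flatten: (B, 2560, 8, 8) -> (B, 64, 2560)
#   projections:      (B, 64, 4096) -> (B, 64, 1536), i.e. 64 image tokens
#                     in the text model's embedding width.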
class MoondreamModel(nn.Module):
    def __init__(self):
        super(MoondreamModel, self).__init__()
        self.vision_encoder = MobileVision()
        self.text_model = AutoModelForCausalLM.from_pretrained(
            "h2oai/h2o-danube3-500m-chat",
            trust_remote_code=True,
            torch_dtype=DTYPE,
            device_map={"": DEVICE}
        )
        # Load the trained weights for the whole model (vision encoder,
        # projections, and text model).
        self.load_state_dict(torch.load('moondream_model_state_dict.pt', map_location=DEVICE))
    def forward(self, images, tokens):
        img_embs = self.vision_encoder(images)
        tok_embs = self.text_model.get_input_embeddings()(tokens)
        # Splice the image tokens in directly after the BOS embedding, before
        # the remaining text tokens.
        inputs_embeds = torch.cat((tok_embs[:, 0:1, :], img_embs, tok_embs[:, 1:, :]), dim=1)
        outputs = self.text_model(inputs_embeds=inputs_embeds)
        return outputs
    @staticmethod
    def load_model():
        model = MoondreamModel().to(DEVICE)
        # Only apply half() if using a GPU
        if torch.cuda.is_available():
            model = model.half()
        return model
    @staticmethod
    def load_tokenizer():
        tokenizer = AutoTokenizer.from_pretrained("h2oai/h2o-danube3-500m-chat", trust_remote_code=True)
        return tokenizer
    @staticmethod
    def preprocess_image(image, img_size=512):
        transform = transforms.Compose([
            transforms.Resize((img_size, img_size)),
            transforms.ToTensor(),
            transforms.Lambda(lambda x: x.to(DTYPE)),
        ])
        image = transform(image).to(DEVICE)
        return image
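    # Note: with the default img_size=512, split_chessboard(x, 2) in
    # MobileVision produces four 256x256 crops, matching the 256x256 input
    # used for the encoder's global pass.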
    @staticmethod
    def generate_caption(model, image, tokenizer, max_length=192):
        model.eval()  # Set model to evaluation mode
        past_key_values = None  # KV cache, filled on the first forward pass
        with torch.no_grad():  # Disable gradients for faster inference
            image = image.unsqueeze(0).to(DEVICE)
            img_embs = model.vision_encoder(image)
            generated = [tokenizer.bos_token_id]
            descriptive_prompt = tokenizer(
                "\n\nDescriptions of the image:",
                add_special_tokens=False
            ).input_ids
            generated.extend(descriptive_prompt)
            for _ in range(max_length):
                if past_key_values is None:
                    # First step: feed BOS, then the image embeddings, then the prompt.
                    input_ids = torch.tensor(generated, dtype=torch.long, device=DEVICE).unsqueeze(0)
                    tok_embs = model.text_model.get_input_embeddings()(input_ids)
                    inputs_embeds = torch.cat((tok_embs[:, 0:1, :], img_embs, tok_embs[:, 1:, :]), dim=1)
                else:
                    # Later steps: feed only the newest token; the KV cache already
                    # holds the attention states for the image and all earlier tokens,
                    # so re-feeding the full sequence would duplicate positions.
                    last_token = torch.tensor([generated[-1:]], dtype=torch.long, device=DEVICE)
                    inputs_embeds = model.text_model.get_input_embeddings()(last_token)
                outputs = model.text_model(
                    inputs_embeds=inputs_embeds,
                    past_key_values=past_key_values,
                    use_cache=True
                )
                next_token_logits = outputs.logits[:, -1, :]
                past_key_values = outputs.past_key_values  # Update KV cache
                next_token = torch.argmax(next_token_logits, dim=-1).item()
                if next_token == tokenizer.sep_token_id:
                    break
                generated.append(next_token)
        return tokenizer.decode(generated, skip_special_tokens=True)
# Example usage (guarded so importing this module does not trigger a model load):
if __name__ == "__main__":
    # Load the model and tokenizer
    model = MoondreamModel.load_model()
    tokenizer = MoondreamModel.load_tokenizer()
    # Load and preprocess an image (assuming image is a PIL Image)
    image = Image.open("path_to_image.jpg")
    preprocessed_image = MoondreamModel.preprocess_image(image)
    # Generate a caption for the image
    caption = MoondreamModel.generate_caption(model, preprocessed_image, tokenizer)
    print("Generated Caption:", caption)