"""Vision-language model wiring: a MobileCLIP image backbone feeding a causal LM.

The image backbone produces a grid of visual features that are projected to
the language model's embedding width and spliced into the token embedding
sequence.
"""

import sys
import types

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from PIL import Image
from transformers import AutoTokenizer, AutoModelForCausalLM

# The MobileCLIP checkout is looked up relative to the current working
# directory, so this entry must be on sys.path before the import below.
sys.path.insert(0,'ml-mobileclip/')
import mobileclip  # noqa: E402  (must follow the sys.path tweak)

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
DTYPE = torch.float16


def split_chessboard(x, num_split):
    """Cut a batch of 4-D maps into a ``num_split`` x ``num_split`` tile grid.

    Args:
        x: Tensor of shape ``(B, C, H, W)``.
        num_split: Number of tiles along each spatial axis.

    Returns:
        Tensor of shape ``(B * num_split**2, C, H // num_split,
        W // num_split)``. Tiles are concatenated along the batch axis in
        row-major tile order: the full batch for tile (0, 0), then the full
        batch for tile (0, 1), and so on.

    Note:
        Tile sizes use floor division, so trailing rows/columns are dropped
        when ``H`` or ``W`` is not a multiple of ``num_split``.
    """
    _, _, height, width = x.shape
    tile_h = height // num_split
    tile_w = width // num_split
    tiles = [
        x[:, :, row * tile_h:(row + 1) * tile_h, col * tile_w:(col + 1) * tile_w]
        for row in range(num_split)
        for col in range(num_split)
    ]
    return torch.cat(tiles, dim=0)


def merge_chessboard(x, num_split):
    """Reassemble tiles produced by :func:`split_chessboard` into full maps.

    Args:
        x: Tensor of shape ``(B * num_split**2, C, h, w)`` whose batch axis
            holds the tiles in the row-major order used by
            :func:`split_chessboard`.
        num_split: Number of tiles along each spatial axis.

    Returns:
        Tensor of shape ``(B, C, h * num_split, w * num_split)``.
    """
    stacked, _, _, _ = x.shape
    per_tile = stacked // (num_split ** 2)
    grid_rows = []
    for row in range(num_split):
        row_tiles = []
        for col in range(num_split):
            start = (row * num_split + col) * per_tile
            row_tiles.append(x[start:start + per_tile])
        # Tiles of one grid row sit side by side along the width axis.
        grid_rows.append(torch.cat(row_tiles, dim=-1))
    # Grid rows are stacked along the height axis.
    return torch.cat(grid_rows, dim=-2)


class FeatureIRLayer(nn.Module):
    """Single ``Linear`` + tanh-approximated ``GELU`` applied to the last axis.

    Args:
        in_dim: Size of the incoming feature axis.
        hidden_dim: Output size of the linear layer.
        out_dim: Accepted for interface compatibility but not used by this
            layer; the module only maps ``in_dim`` to ``hidden_dim``.
    """

    def __init__(self, in_dim: int, hidden_dim: int, out_dim: int) -> None:
        super().__init__()
        # Honour hidden_dim instead of a hard-coded 4096 so the layer width
        # always matches what the caller asked for.
        self.mlp = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.GELU(approximate='tanh'),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``GELU(Linear(x))``."""
        return self.mlp(x)


def _encoder_forward(self, x: torch.Tensor) -> torch.Tensor:
    """Replacement ``forward`` bound onto the MobileCLIP image backbone.

    Runs ``forward_embeddings``, ``forward_tokens`` and ``conv_exp`` in that
    order and returns the ``conv_exp`` output directly; nothing else from the
    backbone's stock ``forward`` is executed.
    """
    x = self.forward_embeddings(x)
    x = self.forward_tokens(x)
    return self.conv_exp(x)


class MobileVision(nn.Module):
    """Image encoder: MobileCLIP-S2 backbone followed by two projections.

    Each image is encoded twice, once resized to 256x256 and once as a 2x2
    grid of tiles, and the two feature maps are concatenated channel-wise
    before being projected.
    """

    def __init__(self):
        super().__init__()
        clip_model, _, _ = mobileclip.create_model_and_transforms(
            'mobileclip_s2', pretrained='mobileclip_s2.pt'
        )
        # Keep only the image backbone, in eval mode, on the module-wide
        # device/dtype (DTYPE is float16, matching the language model).
        self.vision = clip_model.image_encoder.model.eval().to(
            device=DEVICE, dtype=DTYPE
        )
        # Swap in a forward that returns the conv_exp feature map.
        self.vision.forward = types.MethodType(_encoder_forward, self.vision)

        # 1280 * 2: two feature maps are concatenated along channels in
        # forward(); presumably conv_exp emits 1280 channels -- TODO confirm.
        self.projection = FeatureIRLayer(1280*2, 4096, 1536).to(
            device=DEVICE, dtype=DTYPE
        )
        self.projection2 = nn.Linear(4096, 1536).to(device=DEVICE, dtype=DTYPE)

    def forward(self, x):
        """Encode a batch of images into a sequence of projected features.

        Args:
            x: Image batch of shape ``(B, C, H, W)``. Assumes the backbone
                yields an 8x8 map for a 256x256 input so the channel-wise
                concatenation below lines up -- TODO confirm against the
                preprocessing used by callers.

        Returns:
            Tensor of shape ``(B, N, 1536)`` where ``N`` is the number of
            spatial positions in the concatenated feature map.
        """
        # NOTE(review): the original indentation was lost, so the exact
        # extent of this no_grad block could not be recovered. The backbone
        # and the pure reshaping ops are kept inside; the projections are
        # kept outside so they can still receive gradients. Forward outputs
        # are identical either way -- confirm the intended training setup.
        with torch.no_grad():
            # Global view: the whole image at backbone resolution.
            resized_img = F.interpolate(
                x, size=(256, 256), mode='bilinear', align_corners=False
            )
            global_feats = self.vision(resized_img)

            # Local view: encode each of the 2x2 tiles, stitch the tile
            # features back together, then pool down to 8x8.
            tiles = split_chessboard(x, 2)
            tile_feats = self.vision(tiles)
            local_feats = merge_chessboard(tile_feats, 2)
            local_feats = F.interpolate(local_feats, size=(8, 8), mode='area')

            feats = torch.cat([global_feats, local_feats], dim=1)
            # (B, C, h, w) -> (B, C, h*w) -> (B, h*w, C)
            feats = feats.reshape(feats.size(0), feats.size(1), -1)
            feats = feats.permute(0, 2, 1)

        feats = self.projection(feats)
        feats = self.projection2(feats)
        return feats


class MoondreamModel(nn.Module):
    """Couples :class:`MobileVision` with a pretrained causal language model."""

    def __init__(self):
        super().__init__()
        self.vision_encoder = MobileVision()
        self.text_model = AutoModelForCausalLM.from_pretrained(
            "h2oai/h2o-danube3-500m-chat",
            trust_remote_code=True,
            torch_dtype=DTYPE,
            device_map={"": DEVICE},
        )

    def forward(self, images, tokens):
        """Run the language model on token embeddings with image features spliced in.

        Args:
            images: Image batch forwarded to ``self.vision_encoder``.
            tokens: Token id tensor of shape ``(B, T)``; indexed as
                ``[:, 0:1]`` / ``[:, 1:]`` after embedding.

        Returns:
            The language model's output object for the combined sequence.
        """
        img_embs = self.vision_encoder(images)
        tok_embs = self.text_model.get_input_embeddings()(tokens)
        # Sequence layout: first token, all image embeddings, remaining
        # tokens. Presumably the first token is BOS -- verify with caller.
        first_tok = tok_embs[:, 0:1, :]
        rest_toks = tok_embs[:, 1:, :]
        inputs_embeds = torch.cat((first_tok, img_embs, rest_toks), dim=1)
        outputs = self.text_model(inputs_embeds=inputs_embeds)
        return outputs