Update model.py
Browse files
model.py
CHANGED
@@ -6,11 +6,11 @@ from PIL import Image
|
|
6 |
import torchvision.transforms as transforms
|
7 |
import types
|
8 |
import os
|
9 |
-
import sys
|
10 |
import mobileclip
|
11 |
|
|
|
12 |
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
|
13 |
-
DTYPE = torch.float16
|
14 |
|
15 |
def split_chessboard(x, num_split):
|
16 |
B, C, H, W = x.shape
|
@@ -40,7 +40,7 @@ class MobileVision(nn.Module):
|
|
40 |
def __init__(self):
|
41 |
super(MobileVision, self).__init__()
|
42 |
self.vision, _, _ = mobileclip.create_model_and_transforms('mobileclip_s2', pretrained='mobileclip_s2.pt')
|
43 |
-
self.vision = self.vision.image_encoder.model.eval().to(DEVICE).
|
44 |
|
45 |
def new_forward(self, x: torch.Tensor) -> torch.Tensor:
|
46 |
x = self.forward_embeddings(x)
|
@@ -48,8 +48,8 @@ class MobileVision(nn.Module):
|
|
48 |
return self.conv_exp(x)
|
49 |
self.vision.forward = types.MethodType(new_forward, self.vision)
|
50 |
|
51 |
-
self.projection = FeatureIRLayer(1280*2, 4096, 1536).to(DEVICE).
|
52 |
-
self.projection2 = nn.Linear(4096, 1536).to(DEVICE).
|
53 |
|
54 |
def forward(self, x):
|
55 |
with torch.no_grad():
|
@@ -78,7 +78,7 @@ class MoondreamModel(nn.Module):
|
|
78 |
torch_dtype=DTYPE,
|
79 |
device_map={"": DEVICE}
|
80 |
)
|
81 |
-
self.load_state_dict(torch.load('moondream_model_state_dict.pt'))
|
82 |
|
83 |
def forward(self, images, tokens):
|
84 |
img_embs = self.vision_encoder(images)
|
@@ -89,7 +89,10 @@ class MoondreamModel(nn.Module):
|
|
89 |
|
90 |
@staticmethod
|
91 |
def load_model():
|
92 |
-
model = MoondreamModel().to(DEVICE)
|
|
|
|
|
|
|
93 |
return model
|
94 |
|
95 |
@staticmethod
|
@@ -102,7 +105,7 @@ class MoondreamModel(nn.Module):
|
|
102 |
transform = transforms.Compose([
|
103 |
transforms.Resize((img_size, img_size)),
|
104 |
transforms.ToTensor(),
|
105 |
-
transforms.Lambda(lambda x: x.to(
|
106 |
])
|
107 |
image = Image.open(image_path).convert('RGB')
|
108 |
image = transform(image).to(DEVICE)
|
|
|
6 |
import torchvision.transforms as transforms
|
7 |
import types
|
8 |
import os
|
|
|
9 |
import mobileclip
|
10 |
|
11 |
+
# Set the device to GPU if available, otherwise use CPU
|
12 |
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
|
13 |
+
DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32
|
14 |
|
15 |
def split_chessboard(x, num_split):
|
16 |
B, C, H, W = x.shape
|
|
|
40 |
def __init__(self):
|
41 |
super(MobileVision, self).__init__()
|
42 |
self.vision, _, _ = mobileclip.create_model_and_transforms('mobileclip_s2', pretrained='mobileclip_s2.pt')
|
43 |
+
self.vision = self.vision.image_encoder.model.eval().to(DEVICE).to(DTYPE)
|
44 |
|
45 |
def new_forward(self, x: torch.Tensor) -> torch.Tensor:
|
46 |
x = self.forward_embeddings(x)
|
|
|
48 |
return self.conv_exp(x)
|
49 |
self.vision.forward = types.MethodType(new_forward, self.vision)
|
50 |
|
51 |
+
self.projection = FeatureIRLayer(1280*2, 4096, 1536).to(DEVICE).to(DTYPE)
|
52 |
+
self.projection2 = nn.Linear(4096, 1536).to(DEVICE).to(DTYPE)
|
53 |
|
54 |
def forward(self, x):
|
55 |
with torch.no_grad():
|
|
|
78 |
torch_dtype=DTYPE,
|
79 |
device_map={"": DEVICE}
|
80 |
)
|
81 |
+
self.load_state_dict(torch.load('moondream_model_state_dict.pt', map_location=DEVICE))
|
82 |
|
83 |
def forward(self, images, tokens):
|
84 |
img_embs = self.vision_encoder(images)
|
|
|
89 |
|
90 |
@staticmethod
|
91 |
def load_model():
|
92 |
+
model = MoondreamModel().to(DEVICE)
|
93 |
+
# Only apply half() if using a GPU
|
94 |
+
if torch.cuda.is_available():
|
95 |
+
model = model.half()
|
96 |
return model
|
97 |
|
98 |
@staticmethod
|
|
|
105 |
transform = transforms.Compose([
|
106 |
transforms.Resize((img_size, img_size)),
|
107 |
transforms.ToTensor(),
|
108 |
+
transforms.Lambda(lambda x: x.to(DTYPE)),
|
109 |
])
|
110 |
image = Image.open(image_path).convert('RGB')
|
111 |
image = transform(image).to(DEVICE)
|