irotem98 committed on
Commit
f7c35ba
1 Parent(s): 61566d8

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +11 -8
model.py CHANGED
@@ -6,11 +6,11 @@ from PIL import Image
6
  import torchvision.transforms as transforms
7
  import types
8
  import os
9
- import sys
10
  import mobileclip
11
 
 
12
  DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
13
- DTYPE = torch.float16
14
 
15
  def split_chessboard(x, num_split):
16
  B, C, H, W = x.shape
@@ -40,7 +40,7 @@ class MobileVision(nn.Module):
40
  def __init__(self):
41
  super(MobileVision, self).__init__()
42
  self.vision, _, _ = mobileclip.create_model_and_transforms('mobileclip_s2', pretrained='mobileclip_s2.pt')
43
- self.vision = self.vision.image_encoder.model.eval().to(DEVICE).half()
44
 
45
  def new_forward(self, x: torch.Tensor) -> torch.Tensor:
46
  x = self.forward_embeddings(x)
@@ -48,8 +48,8 @@ class MobileVision(nn.Module):
48
  return self.conv_exp(x)
49
  self.vision.forward = types.MethodType(new_forward, self.vision)
50
 
51
- self.projection = FeatureIRLayer(1280*2, 4096, 1536).to(DEVICE).half()
52
- self.projection2 = nn.Linear(4096, 1536).to(DEVICE).half()
53
 
54
  def forward(self, x):
55
  with torch.no_grad():
@@ -78,7 +78,7 @@ class MoondreamModel(nn.Module):
78
  torch_dtype=DTYPE,
79
  device_map={"": DEVICE}
80
  )
81
- self.load_state_dict(torch.load('moondream_model_state_dict.pt'))
82
 
83
  def forward(self, images, tokens):
84
  img_embs = self.vision_encoder(images)
@@ -89,7 +89,10 @@ class MoondreamModel(nn.Module):
89
 
90
  @staticmethod
91
  def load_model():
92
- model = MoondreamModel().to(DEVICE).half()
 
 
 
93
  return model
94
 
95
  @staticmethod
@@ -102,7 +105,7 @@ class MoondreamModel(nn.Module):
102
  transform = transforms.Compose([
103
  transforms.Resize((img_size, img_size)),
104
  transforms.ToTensor(),
105
- transforms.Lambda(lambda x: x.to(torch.float16)),
106
  ])
107
  image = Image.open(image_path).convert('RGB')
108
  image = transform(image).to(DEVICE)
 
6
  import torchvision.transforms as transforms
7
  import types
8
  import os
 
9
  import mobileclip
10
 
11
+ # Set the device to GPU if available, otherwise use CPU
12
  DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
13
+ DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32
14
 
15
  def split_chessboard(x, num_split):
16
  B, C, H, W = x.shape
 
40
  def __init__(self):
41
  super(MobileVision, self).__init__()
42
  self.vision, _, _ = mobileclip.create_model_and_transforms('mobileclip_s2', pretrained='mobileclip_s2.pt')
43
+ self.vision = self.vision.image_encoder.model.eval().to(DEVICE).to(DTYPE)
44
 
45
  def new_forward(self, x: torch.Tensor) -> torch.Tensor:
46
  x = self.forward_embeddings(x)
 
48
  return self.conv_exp(x)
49
  self.vision.forward = types.MethodType(new_forward, self.vision)
50
 
51
+ self.projection = FeatureIRLayer(1280*2, 4096, 1536).to(DEVICE).to(DTYPE)
52
+ self.projection2 = nn.Linear(4096, 1536).to(DEVICE).to(DTYPE)
53
 
54
  def forward(self, x):
55
  with torch.no_grad():
 
78
  torch_dtype=DTYPE,
79
  device_map={"": DEVICE}
80
  )
81
+ self.load_state_dict(torch.load('moondream_model_state_dict.pt', map_location=DEVICE))
82
 
83
  def forward(self, images, tokens):
84
  img_embs = self.vision_encoder(images)
 
89
 
90
  @staticmethod
91
  def load_model():
92
+ model = MoondreamModel().to(DEVICE)
93
+ # Only apply half() if using a GPU
94
+ if torch.cuda.is_available():
95
+ model = model.half()
96
  return model
97
 
98
  @staticmethod
 
105
  transform = transforms.Compose([
106
  transforms.Resize((img_size, img_size)),
107
  transforms.ToTensor(),
108
+ transforms.Lambda(lambda x: x.to(DTYPE)),
109
  ])
110
  image = Image.open(image_path).convert('RGB')
111
  image = transform(image).to(DEVICE)