irotem98 committed on
Commit a378671
1 Parent(s): 8fe453b

Update model.py

Files changed (1)
  1. model.py +28 -7
model.py CHANGED
@@ -5,7 +5,6 @@ from transformers import AutoTokenizer, AutoModelForCausalLM
 from PIL import Image
 import torchvision.transforms as transforms
 import types
-import os
 import mobileclip
 
 # Set the device to GPU if available, otherwise use CPU
@@ -106,15 +105,15 @@ class MoondreamModel(nn.Module):
             transforms.ToTensor(),
             transforms.Lambda(lambda x: x.to(DTYPE)),
         ])
-        # The `image` is now a PIL image, so no need to load it from the file path
         image = transform(image).to(DEVICE)
         return image
 
-
     @staticmethod
     def generate_caption(model, image, tokenizer, max_length=192):
-        model.eval()
-        with torch.no_grad():
+        model.eval()  # Set model to evaluation mode
+        past_key_values = None  # Initialize KV cache
+
+        with torch.no_grad():  # Disable gradients for faster inference
             image = image.unsqueeze(0).to(DEVICE)
             img_embs = model.vision_encoder(image)
 
@@ -129,13 +128,35 @@ class MoondreamModel(nn.Module):
             input_ids = torch.tensor(generated, dtype=torch.long, device=DEVICE).unsqueeze(0)
             tok_embs = model.text_model.get_input_embeddings()(input_ids)
             inputs_embeds = torch.cat((tok_embs[:, 0:1, :], img_embs, tok_embs[:, 1:, :]), dim=1)
-            outputs = model.text_model(inputs_embeds=inputs_embeds)
+
+            # Use the KV cache to avoid recomputation
+            outputs = model.text_model(
+                inputs_embeds=inputs_embeds,
+                past_key_values=past_key_values,
+                use_cache=True
+            )
+
             next_token_logits = outputs.logits[:, -1, :]
-            next_token = torch.argmax(next_token_logits, dim=-1).item()
+            past_key_values = outputs.past_key_values  # Update KV cache
 
+            next_token = torch.argmax(next_token_logits, dim=-1).item()
             if next_token == tokenizer.sep_token_id:
                 break
 
             generated.append(next_token)
 
         return tokenizer.decode(generated, skip_special_tokens=True)
+
+# Example usage:
+
+# Load the model and tokenizer
+model = MoondreamModel.load_model()
+tokenizer = MoondreamModel.load_tokenizer()
+
+# Load and preprocess an image (assuming image is a PIL Image)
+image = Image.open("path_to_image.jpg")
+preprocessed_image = MoondreamModel.preprocess_image(image)
+
+# Generate a caption for the image
+caption = MoondreamModel.generate_caption(model, preprocessed_image, tokenizer)
+print("Generated Caption:", caption)