Update model.py
model.py CHANGED
@@ -5,7 +5,6 @@ from transformers import AutoTokenizer, AutoModelForCausalLM
 from PIL import Image
 import torchvision.transforms as transforms
 import types
-import os
 import mobileclip
 
 # Set the device to GPU if available, otherwise use CPU
@@ -106,15 +105,15 @@ class MoondreamModel(nn.Module):
             transforms.ToTensor(),
             transforms.Lambda(lambda x: x.to(DTYPE)),
         ])
-        # The `image` is now a PIL image, so no need to load it from the file path
         image = transform(image).to(DEVICE)
         return image
 
-
     @staticmethod
     def generate_caption(model, image, tokenizer, max_length=192):
-        model.eval()
-        with torch.no_grad():
+        model.eval()  # Set model to evaluation mode
+        past_key_values = None  # Initialize KV cache
+
+        with torch.no_grad():  # Disable gradients for faster inference
             image = image.unsqueeze(0).to(DEVICE)
             img_embs = model.vision_encoder(image)
 
@@ -129,13 +128,35 @@ class MoondreamModel(nn.Module):
                 input_ids = torch.tensor(generated, dtype=torch.long, device=DEVICE).unsqueeze(0)
                 tok_embs = model.text_model.get_input_embeddings()(input_ids)
                 inputs_embeds = torch.cat((tok_embs[:, 0:1, :], img_embs, tok_embs[:, 1:, :]), dim=1)
-                outputs = model.text_model(inputs_embeds=inputs_embeds)
+
+                # Use the KV cache to avoid recomputation
+                outputs = model.text_model(
+                    inputs_embeds=inputs_embeds,
+                    past_key_values=past_key_values,
+                    use_cache=True
+                )
+
                 next_token_logits = outputs.logits[:, -1, :]
-                next_token = torch.argmax(next_token_logits, dim=-1).item()
+                past_key_values = outputs.past_key_values  # Update KV cache
 
+                next_token = torch.argmax(next_token_logits, dim=-1).item()
                 if next_token == tokenizer.sep_token_id:
                     break
 
                 generated.append(next_token)
 
         return tokenizer.decode(generated, skip_special_tokens=True)
+
+# Example usage:
+
+# Load the model and tokenizer
+model = MoondreamModel.load_model()
+tokenizer = MoondreamModel.load_tokenizer()
+
+# Load and preprocess an image (assuming image is a PIL Image)
+image = Image.open("path_to_image.jpg")
+preprocessed_image = MoondreamModel.preprocess_image(image)
+
+# Generate a caption for the image
+caption = MoondreamModel.generate_caption(model, preprocessed_image, tokenizer)
+print("Generated Caption:", caption)
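For reference, the additions above follow the standard key/value-cache decoding pattern: pass past_key_values back into the model with use_cache=True and refresh it from each step's outputs. The sketch below shows the same pattern against a plain Hugging Face causal LM so it can be run independently of this repo's vision encoder; the gpt2 checkpoint, the prompt, and the 32-token budget are illustrative stand-ins, not part of this commit.

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2").to(device)
model.eval()

input_ids = tokenizer("A photo of", return_tensors="pt").input_ids.to(device)
past_key_values = None  # KV cache starts empty
generated = input_ids

with torch.no_grad():
    for _ in range(32):
        # Once the cache is warm, only the newest token is fed in; the cache
        # supplies the keys/values for every earlier position.
        step_input = generated if past_key_values is None else generated[:, -1:]
        outputs = model(
            input_ids=step_input,
            past_key_values=past_key_values,
            use_cache=True,
        )
        past_key_values = outputs.past_key_values  # reuse on the next step
        next_token = torch.argmax(outputs.logits[:, -1, :], dim=-1, keepdim=True)
        generated = torch.cat([generated, next_token], dim=1)
        if next_token.item() == tokenizer.eos_token_id:
            break

print(tokenizer.decode(generated[0], skip_special_tokens=True))

In this sketch only the latest token is passed once the cache is populated, which is what keeps the per-step cost roughly constant as the generated sequence grows.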