Update model.py
Browse files
model.py
CHANGED
@@ -101,16 +101,17 @@ class MoondreamModel(nn.Module):
|
|
101 |
return tokenizer
|
102 |
|
103 |
@staticmethod
|
104 |
-
def preprocess_image(
|
105 |
transform = transforms.Compose([
|
106 |
transforms.Resize((img_size, img_size)),
|
107 |
transforms.ToTensor(),
|
108 |
transforms.Lambda(lambda x: x.to(DTYPE)),
|
109 |
])
|
110 |
-
image
|
111 |
image = transform(image).to(DEVICE)
|
112 |
return image
|
113 |
|
|
|
114 |
@staticmethod
|
115 |
def generate_caption(model, image, tokenizer, max_length=128):
|
116 |
model.eval()
|
|
|
101 |
return tokenizer
|
102 |
|
103 |
@staticmethod
|
104 |
+
def preprocess_image(image, img_size=512):
|
105 |
transform = transforms.Compose([
|
106 |
transforms.Resize((img_size, img_size)),
|
107 |
transforms.ToTensor(),
|
108 |
transforms.Lambda(lambda x: x.to(DTYPE)),
|
109 |
])
|
110 |
+
# The `image` is now a PIL image, so no need to load it from the file path
|
111 |
image = transform(image).to(DEVICE)
|
112 |
return image
|
113 |
|
114 |
+
|
115 |
@staticmethod
|
116 |
def generate_caption(model, image, tokenizer, max_length=128):
|
117 |
model.eval()
|