irotem98 commited on
Commit
d3d22eb
1 Parent(s): f7c35ba

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +3 -2
model.py CHANGED
@@ -101,16 +101,17 @@ class MoondreamModel(nn.Module):
101
  return tokenizer
102
 
103
  @staticmethod
104
- def preprocess_image(image_path, img_size=512):
105
  transform = transforms.Compose([
106
  transforms.Resize((img_size, img_size)),
107
  transforms.ToTensor(),
108
  transforms.Lambda(lambda x: x.to(DTYPE)),
109
  ])
110
- image = Image.open(image_path).convert('RGB')
111
  image = transform(image).to(DEVICE)
112
  return image
113
 
 
114
  @staticmethod
115
  def generate_caption(model, image, tokenizer, max_length=128):
116
  model.eval()
 
101
  return tokenizer
102
 
103
  @staticmethod
104
+ def preprocess_image(image, img_size=512):
105
  transform = transforms.Compose([
106
  transforms.Resize((img_size, img_size)),
107
  transforms.ToTensor(),
108
  transforms.Lambda(lambda x: x.to(DTYPE)),
109
  ])
110
+ # The `image` is now a PIL image, so no need to load it from the file path
111
  image = transform(image).to(DEVICE)
112
  return image
113
 
114
+
115
  @staticmethod
116
  def generate_caption(model, image, tokenizer, max_length=128):
117
  model.eval()