fancyfeast committed on
Commit
b72bef3
1 Parent(s): df9e86f

Prepare images correctly

Browse files
Files changed (1) hide show
  1. app.py +32 -2
app.py CHANGED
@@ -4,15 +4,45 @@ import huggingface_hub
4
  from PIL import Image
5
  import torch.amp.autocast_mode
6
  from pathlib import Path
 
 
7
 
8
 
9
  MODEL_REPO = "fancyfeast/joytag"
10
 
11
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  @torch.no_grad()
13
  def predict(image: Image.Image):
14
- with torch.amp.autocast_mode.autocast('cuda', enabled=True):
15
- preds = model(image)
 
 
 
 
 
16
  tag_preds = preds['tags'].sigmoid().cpu()
17
 
18
  return {top_tags[i]: tag_preds[i] for i in range(len(top_tags))}
 
4
  from PIL import Image
5
  import torch.amp.autocast_mode
6
  from pathlib import Path
7
+ import torch
8
+ import torchvision.transforms.functional as TVF
9
 
10
 
11
  MODEL_REPO = "fancyfeast/joytag"
12
 
13
 
14
def prepare_image(image: Image.Image, target_size: int) -> torch.Tensor:
    """Convert a PIL image into the normalized tensor fed to the model.

    The image is pasted, centered, onto a white RGB square canvas whose side
    equals the longer of the image's two dimensions. The canvas is resized
    with bicubic resampling when its side differs from ``target_size``,
    converted to a tensor, divided by 255.0 and normalized per channel with
    fixed mean/std constants.

    Args:
        image: Source picture; it is pasted onto an RGB canvas.
        target_size: Side length, in pixels, of the square canvas after the
            resize step.

    Returns:
        The tensor returned by ``TVF.normalize``.
    """
    width, height = image.size
    side = max(width, height)

    # Center the picture on a white square canvas (pad to square).
    offset = ((side - width) // 2, (side - height) // 2)
    canvas = Image.new('RGB', (side, side), (255, 255, 255))
    canvas.paste(image, offset)

    # Resample only when the canvas is not already the requested size.
    if side != target_size:
        canvas = canvas.resize((target_size, target_size), Image.BICUBIC)

    # Pixel values are divided by 255.0 before normalization.
    scaled = TVF.pil_to_tensor(canvas) / 255.0

    # Per-channel mean/std normalization with the model's fixed constants.
    return TVF.normalize(
        scaled,
        mean=[0.48145466, 0.4578275, 0.40821073],
        std=[0.26862954, 0.26130258, 0.27577711],
    )
35
+
36
+
37
@torch.no_grad()
def predict(image: Image.Image):
    """Score every entry of ``top_tags`` for a single image.

    The image is preprocessed with ``prepare_image`` at ``model.image_size``,
    wrapped into a one-element batch, and run through ``model`` under CPU
    autocast. The ``'tags'`` output is passed through a sigmoid.

    Args:
        image: Picture to tag.

    Returns:
        Dict mapping each tag in ``top_tags`` to its score (a 0-dim tensor).

    Raises:
        IndexError: If the model yields fewer scores than there are tags.
    """
    image_tensor = prepare_image(image, model.image_size)
    batch = {
        # unsqueeze(0) adds a leading batch axis of size 1.
        'image': image_tensor.unsqueeze(0),
    }

    with torch.amp.autocast_mode.autocast('cpu', enabled=True):
        preds = model(batch)
        tag_preds = preds['tags'].sigmoid().cpu()

    # NOTE(review): the input carries a batch axis of size 1, so the output
    # presumably does too (shape (1, num_tags)) -- confirm against the model.
    # Indexing that directly by tag position would hit the batch axis, so
    # flatten first; this is a no-op if the output is already 1-D.
    tag_scores = tag_preds.reshape(-1)

    return {tag: tag_scores[i] for i, tag in enumerate(top_tags)}