abhicake committed
Commit
2d7ccaa
1 Parent(s): 742337e

Upload 3 files

Files changed (3)
  1. MyPipe.py +76 -0
  2. preprocessor_config.json +23 -0
  3. utilities.py +25 -0
MyPipe.py ADDED
import torch, os
import torch.nn.functional as F
from torchvision.transforms.functional import normalize
import numpy as np
from transformers import Pipeline
from transformers.image_utils import load_image
from skimage import io
from PIL import Image

class RMBGPipe(Pipeline):
    def __init__(self, **kwargs):
        Pipeline.__init__(self, **kwargs)
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        self.model.eval()

    def _sanitize_parameters(self, **kwargs):
        # route pipeline kwargs to the preprocess / postprocess steps
        preprocess_kwargs = {}
        postprocess_kwargs = {}
        if "model_input_size" in kwargs:
            preprocess_kwargs["model_input_size"] = kwargs["model_input_size"]
        if "return_mask" in kwargs:
            postprocess_kwargs["return_mask"] = kwargs["return_mask"]
        return preprocess_kwargs, {}, postprocess_kwargs

    def preprocess(self, input_image, model_input_size: list = [1024, 1024]):
        # load the image, remember its original size, and build the model input tensor
        orig_im = load_image(input_image)
        orig_im = np.array(orig_im)
        orig_im_size = orig_im.shape[0:2]
        preprocessed_image = self.preprocess_image(orig_im, model_input_size).to(self.device)
        inputs = {
            "preprocessed_image": preprocessed_image,
            "orig_im_size": orig_im_size,
            "input_image": input_image,
        }
        return inputs

    def _forward(self, inputs):
        result = self.model(inputs.pop("preprocessed_image"))
        inputs["result"] = result
        return inputs

    def postprocess(self, inputs, return_mask: bool = False):
        result = inputs.pop("result")
        orig_im_size = inputs.pop("orig_im_size")
        input_image = inputs.pop("input_image")
        result_image = self.postprocess_image(result[0][0], orig_im_size)
        pil_im = Image.fromarray(result_image)
        if return_mask:
            # return only the predicted alpha mask
            return pil_im
        # composite the original image over a transparent background using the mask
        no_bg_image = Image.new("RGBA", pil_im.size, (0, 0, 0, 0))
        input_image = load_image(input_image)
        no_bg_image.paste(input_image, mask=pil_im)
        return no_bg_image

    # utility functions
    def preprocess_image(self, im: np.ndarray, model_input_size: list = [1024, 1024]) -> torch.Tensor:
        # same as utilities.py with a minor modification (no uint8 cast after resizing)
        if len(im.shape) < 3:
            im = im[:, :, np.newaxis]
        im_tensor = torch.tensor(im, dtype=torch.float32).permute(2, 0, 1)
        im_tensor = F.interpolate(torch.unsqueeze(im_tensor, 0), size=model_input_size, mode='bilinear')
        image = torch.divide(im_tensor, 255.0)
        image = normalize(image, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
        return image

    def postprocess_image(self, result: torch.Tensor, im_size: list) -> np.ndarray:
        # resize the prediction back to the original size and rescale it to 0-255
        result = torch.squeeze(F.interpolate(result, size=im_size, mode='bilinear'), 0)
        ma = torch.max(result)
        mi = torch.min(result)
        result = (result - mi) / (ma - mi)
        im_array = (result * 255).permute(1, 2, 0).cpu().data.numpy().astype(np.uint8)
        im_array = np.squeeze(im_array)
        return im_array
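Usage sketch (not part of the uploaded files): a minimal example of calling this custom pipeline through transformers, assuming the hosting model repo registers RMBGPipe under the image-segmentation task and that trust_remote_code is enabled. The model id and file paths below are placeholders.

from transformers import pipeline

# placeholder model id; the repo must register RMBGPipe as its custom pipeline
pipe = pipeline("image-segmentation", model="your-namespace/your-rmbg-repo", trust_remote_code=True)

# returns an RGBA PIL image with the background removed
no_bg = pipe("example_input.jpg")
no_bg.save("no_bg.png")

# returns only the predicted alpha mask
mask = pipe("example_input.jpg", return_mask=True)
mask.save("mask.png")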
preprocessor_config.json ADDED
{
  "do_normalize": true,
  "do_pad": false,
  "do_rescale": true,
  "do_resize": true,
  "image_mean": [
    0.5,
    0.5,
    0.5
  ],
  "feature_extractor_type": "ImageFeatureExtractor",
  "image_std": [
    1,
    1,
    1
  ],
  "resample": 2,
  "rescale_factor": 0.00392156862745098,
  "size": {
    "width": 1024,
    "height": 1024
  }
}
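This config mirrors the tensor preprocessing in MyPipe.preprocess_image: resize to 1024x1024 (resample 2 is PIL bilinear), rescale by 1/255, then normalize with mean 0.5 and std 1.0. A small illustrative sketch of that arithmetic (the function name is invented for illustration, not part of the repo):

import numpy as np

def apply_config(pixels: np.ndarray) -> np.ndarray:
    # pixels: float or uint8 array already resized to 1024x1024
    rescaled = pixels.astype(np.float32) * 0.00392156862745098  # rescale_factor = 1/255
    return (rescaled - 0.5) / 1.0                               # image_mean / image_std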
utilities.py ADDED
import torch
import torch.nn.functional as F
from torchvision.transforms.functional import normalize
import numpy as np

def preprocess_image(im: np.ndarray, model_input_size: list) -> torch.Tensor:
    # ensure the image has a channel dimension
    if len(im.shape) < 3:
        im = im[:, :, np.newaxis]
    # orig_im_size=im.shape[0:2]
    im_tensor = torch.tensor(im, dtype=torch.float32).permute(2, 0, 1)
    # note: the uint8 cast after resizing is dropped in MyPipe.py
    im_tensor = F.interpolate(torch.unsqueeze(im_tensor, 0), size=model_input_size, mode='bilinear').type(torch.uint8)
    image = torch.divide(im_tensor, 255.0)
    image = normalize(image, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
    return image


def postprocess_image(result: torch.Tensor, im_size: list) -> np.ndarray:
    # resize the prediction back to the original size and rescale it to 0-255
    result = torch.squeeze(F.interpolate(result, size=im_size, mode='bilinear'), 0)
    ma = torch.max(result)
    mi = torch.min(result)
    result = (result - mi) / (ma - mi)
    im_array = (result * 255).permute(1, 2, 0).cpu().data.numpy().astype(np.uint8)
    im_array = np.squeeze(im_array)
    return im_array
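Standalone usage sketch for utilities.py (not part of the uploaded files): the model, file paths, and DummySaliencyNet class below are placeholders standing in for an RMBG-style network that maps a 1x3xHxW tensor to nested 1x1xHxW predictions.

import torch
from skimage import io
from utilities import preprocess_image, postprocess_image

# placeholder network so the sketch runs end to end; swap in the real model
class DummySaliencyNet(torch.nn.Module):
    def forward(self, x):
        return [[x.mean(dim=1, keepdim=True)]]

model = DummySaliencyNet().eval()

image = io.imread("example_input.jpg")          # any RGB image path
orig_size = image.shape[0:2]
tensor = preprocess_image(image, [1024, 1024])

with torch.no_grad():
    result = model(tensor)

mask = postprocess_image(result[0][0], orig_size)  # uint8 HxW array in [0, 255]
io.imsave("mask.png", mask)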