huzey committed
Commit
2b98806
1 Parent(s): c7e7e12

initial commit

README.md CHANGED
@@ -1,13 +1,44 @@
- ---
- title: Ncut Pytorch
- emoji: 📈
- colorFrom: yellow
- colorTo: pink
- sdk: gradio
- sdk_version: 4.42.0
- app_file: app.py
- pinned: false
- license: apache-2.0
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+
+
+
+ Documentation: [https://ncut-pytorch.readthedocs.io/](https://ncut-pytorch.readthedocs.io/)
+
+
+ ## NCUT: Nyström Normalized Cut
+
+ **Normalized Cut**, also known as spectral clustering, is a graph-based method that analyzes data grouping in the eigenvector space of the affinity matrix. It was widely used for unsupervised segmentation in the 2000s.
+
+ **Nyström Normalized Cut** is a new approximation algorithm developed for large-scale graph cuts: a large graph of one million nodes can be processed in under 10 seconds (CPU) or 2 seconds (GPU).
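+
+ A minimal usage sketch (the `NCUT` call mirrors the one in `app.py` in this commit; the random features and their size are placeholders, not a benchmark):
+
+ ```python
+ import torch
+ from ncut_pytorch import NCUT
+
+ # one million graph nodes, each with a 768-dim feature (e.g. ViT patch tokens)
+ features = torch.rand(1_000_000, 768)
+
+ # Nyström approximation: only `num_sample` nodes are sampled to approximate the full graph
+ eigvecs, eigvals = NCUT(
+     num_eig=100,
+     num_sample=10000,
+     knn=10,
+     affinity_focal_gamma=0.3,
+     # device="cuda:0",  # uncomment to run on GPU
+ ).fit_transform(features)
+ print(eigvecs.shape, eigvals.shape)
+ ```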
+
+ ## Gallery
+ TODO
+
+ ## Installation
+
+ Install from PyPI. Our package is built on [PyTorch](https://pytorch.org/get-started/locally/); the command below assumes PyTorch is already installed.
+
+ ```shell
+ pip install ncut-pytorch
+ ```
+
+ [Install PyTorch](https://pytorch.org/get-started/locally/) first if you haven't already:
+ ```shell
+ pip install torch
+ ```
+ ## Why NCUT
+
+ Normalized cut offers two advantages:
+
+ 1. Soft-cluster assignments as eigenvectors
+
+ 2. Hierarchical clustering by varying the number of eigenvectors (see the sketch below)
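+
+ A minimal sketch of the second point (the API calls follow `app.py` in this commit; the feature tensor, eigenvector counts, and granularity comments are illustrative, echoing the demo's slider hint that more eigenvectors give more object parts):
+
+ ```python
+ import torch
+ from ncut_pytorch import NCUT, rgb_from_tsne_3d
+
+ features = torch.rand(10_000, 768)  # placeholder for real backbone features
+
+ # compute a large eigenvector basis once
+ eigvecs, eigvals = NCUT(
+     num_eig=1000, num_sample=10000, knn=10, affinity_focal_gamma=0.3
+ ).fit_transform(features)
+
+ # coarse grouping: keep only the leading eigenvectors (whole objects)
+ _, rgb_coarse = rgb_from_tsne_3d(eigvecs[:, :20], num_sample=1000, perplexity=500, knn=10)
+
+ # fine grouping: keep more eigenvectors (object parts)
+ _, rgb_fine = rgb_from_tsne_3d(eigvecs[:, :200], num_sample=1000, perplexity=500, knn=10)
+ ```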
+
+ Please see [NCUT and t-SNE/UMAP](compare.md) for a full comparison.
+
+
+ > paper in prep, Yang 2024
+ >
+ > AlignedCut: Visual Concepts Discovery on Brain-Guided Universal Feature Space, Huzheng Yang, James Gee\*, Jianbo Shi\*, 2024
+ >
+ > Normalized Cuts and Image Segmentation, Jianbo Shi and Jitendra Malik, 2000
+ >
app.py ADDED
@@ -0,0 +1,470 @@
+ # %%
+ from typing import Optional, Tuple
+ from einops import rearrange
+ import torch
+ from PIL import Image
+ import torchvision.transforms as transforms
+ from torch import nn
+ import numpy as np
+ import os  # needed by image_sam_feature() below
+
+ import gradio as gr
+
+
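+ # The three backbone wrappers below (SAM, DiNOv2, CLIP) follow the same pattern:
+ # forward() of a single transformer-block class is monkey-patched so that every
+ # block caches its attention output, MLP output, and residual-stream ("block")
+ # output on the module; each wrapper's forward() then stacks those caches into
+ # per-layer tensors that are later fed to NCUT.
+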
+ class SAM(torch.nn.Module):
+     def __init__(self, checkpoint="/data/sam_model/sam_vit_b_01ec64.pth", **kwargs):
+         super().__init__(**kwargs)
+         from segment_anything import sam_model_registry, SamPredictor
+         from segment_anything.modeling.sam import Sam
+
+         sam: Sam = sam_model_registry["vit_b"](checkpoint=checkpoint)
+
+         from segment_anything.modeling.image_encoder import (
+             window_partition,
+             window_unpartition,
+         )
+
+         def new_block_forward(self, x: torch.Tensor) -> torch.Tensor:
+             shortcut = x
+             x = self.norm1(x)
+             # Window partition
+             if self.window_size > 0:
+                 H, W = x.shape[1], x.shape[2]
+                 x, pad_hw = window_partition(x, self.window_size)
+
+             x = self.attn(x)
+             # Reverse window partition
+             if self.window_size > 0:
+                 x = window_unpartition(x, self.window_size, pad_hw, (H, W))
+             self.attn_output = x.clone()
+
+             x = shortcut + x
+             mlp_output = self.mlp(self.norm2(x))
+             self.mlp_output = mlp_output.clone()
+             x = x + mlp_output
+             self.block_output = x.clone()
+
+             return x
+
+         # patch the block class so every block caches attn / mlp / block outputs
+         setattr(sam.image_encoder.blocks[0].__class__, "forward", new_block_forward)
+
+         self.image_encoder = sam.image_encoder
+         self.image_encoder.eval()
+         # self.image_encoder = self.image_encoder.cuda()
+
+     @torch.no_grad()
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         with torch.no_grad():
+             x = torch.nn.functional.interpolate(x, size=(1024, 1024), mode="bilinear")
+             out = self.image_encoder(x)
+
+         attn_outputs, mlp_outputs, block_outputs = [], [], []
+         for i, blk in enumerate(self.image_encoder.blocks):
+             attn_outputs.append(blk.attn_output)
+             mlp_outputs.append(blk.mlp_output)
+             block_outputs.append(blk.block_output)
+         attn_outputs = torch.stack(attn_outputs)
+         mlp_outputs = torch.stack(mlp_outputs)
+         block_outputs = torch.stack(block_outputs)
+         return attn_outputs, mlp_outputs, block_outputs
+
+
+ def image_sam_feature(
+     images,
+     resolution=(1024, 1024),
+     node_type="block",
+     layer=-1,
+ ):
+
+     transform = transforms.Compose(
+         [
+             transforms.Resize(resolution),
+             transforms.ToTensor(),
+             transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
+         ]
+     )
+
+     # download the SAM ViT-B checkpoint on first use
+     checkpoint = "sam_vit_b_01ec64.pth"
+     if not os.path.exists(checkpoint):
+         checkpoint_url = 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth'
+         import requests
+         r = requests.get(checkpoint_url)
+         with open(checkpoint, 'wb') as f:
+             f.write(r.content)
+
+     feat_extractor = SAM(checkpoint=checkpoint)
+
+     # attn_outputs, mlp_outputs, block_outputs = [], [], []
+     outputs = []
+     for i, image in enumerate(images):
+         torch_image = transform(image)
+         attn_output, mlp_output, block_output = feat_extractor(
+             # torch_image.unsqueeze(0).cuda()
+             torch_image.unsqueeze(0)
+         )
+         out_dict = {
+             "attn": attn_output,
+             "mlp": mlp_output,
+             "block": block_output,
+         }
+         out = out_dict[node_type]
+         out = out[layer]
+         outputs.append(out.cpu())
+     outputs = torch.cat(outputs, dim=0)
+     return outputs
+
+
+ class DiNOv2(torch.nn.Module):
+     def __init__(self, ver="dinov2_vitb14_reg"):
+         super().__init__()
+         self.dinov2 = torch.hub.load("facebookresearch/dinov2", ver)
+         self.dinov2.requires_grad_(False)
+         self.dinov2.eval()
+         # self.dinov2 = self.dinov2.cuda()
+
+         def new_block_forward(self, x: torch.Tensor) -> torch.Tensor:
+             def attn_residual_func(x):
+                 return self.ls1(self.attn(self.norm1(x)))
+
+             def ffn_residual_func(x):
+                 return self.ls2(self.mlp(self.norm2(x)))
+
+             attn_output = attn_residual_func(x)
+             self.attn_output = attn_output.clone()
+             x = x + attn_output
+             mlp_output = ffn_residual_func(x)
+             self.mlp_output = mlp_output.clone()
+             x = x + mlp_output
+             block_output = x
+             self.block_output = block_output.clone()
+             return x
+
+         setattr(self.dinov2.blocks[0].__class__, "forward", new_block_forward)
+
+     @torch.no_grad()
+     def forward(self, x):
+
+         out = self.dinov2(x)
+
+         attn_outputs, mlp_outputs, block_outputs = [], [], []
+         for i, blk in enumerate(self.dinov2.blocks):
+             attn_outputs.append(blk.attn_output)
+             mlp_outputs.append(blk.mlp_output)
+             block_outputs.append(blk.block_output)
+
+         attn_outputs = torch.stack(attn_outputs)
+         mlp_outputs = torch.stack(mlp_outputs)
+         block_outputs = torch.stack(block_outputs)
+         return attn_outputs, mlp_outputs, block_outputs
+
+
+ def image_dino_feature(images, resolution=(448, 448), node_type="block", layer=-1):
+
+     transform = transforms.Compose(
+         [
+             transforms.Resize(resolution),
+             transforms.ToTensor(),
+             transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
+         ]
+     )
+
+     feat_extractor = DiNOv2()
+
+     outputs = []
+     for i, image in enumerate(images):
+         torch_image = transform(image)
+         attn_output, mlp_output, block_output = feat_extractor(
+             # torch_image.unsqueeze(0).cuda()
+             torch_image.unsqueeze(0)
+         )
+         out_dict = {
+             "attn": attn_output,
+             "mlp": mlp_output,
+             "block": block_output,
+         }
+         out = out_dict[node_type]
+         out = out[layer]
+         outputs.append(out.cpu())
+     outputs = torch.cat(outputs, dim=0)
+     # drop the CLS token and 4 register tokens, keep the 32x32 patch grid
+     outputs = rearrange(outputs[:, 5:, :], "b (h w) c -> b h w c", h=32, w=32)
+     return outputs
+
+
+ class CLIP(torch.nn.Module):
+     def __init__(self):
+         super().__init__()
+
+         from transformers import CLIPProcessor, CLIPModel
+
+         model = CLIPModel.from_pretrained("openai/clip-vit-base-patch16")
+         # processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch16")
+         self.model = model.eval()
+         # self.model = self.model.cuda()
+
+         def new_forward(
+             self,
+             hidden_states: torch.Tensor,
+             attention_mask: torch.Tensor,
+             causal_attention_mask: torch.Tensor,
+             output_attentions: Optional[bool] = False,
+         ) -> Tuple[torch.FloatTensor]:
+
+             residual = hidden_states
+
+             hidden_states = self.layer_norm1(hidden_states)
+             hidden_states, attn_weights = self.self_attn(
+                 hidden_states=hidden_states,
+                 attention_mask=attention_mask,
+                 causal_attention_mask=causal_attention_mask,
+                 output_attentions=output_attentions,
+             )
+             self.attn_output = hidden_states.clone()
+             hidden_states = residual + hidden_states
+
+             residual = hidden_states
+             hidden_states = self.layer_norm2(hidden_states)
+             hidden_states = self.mlp(hidden_states)
+             self.mlp_output = hidden_states.clone()
+
+             hidden_states = residual + hidden_states
+
+             outputs = (hidden_states,)
+
+             if output_attentions:
+                 outputs += (attn_weights,)
+
+             self.block_output = hidden_states.clone()
+             return outputs
+
+         setattr(self.model.vision_model.encoder.layers[0].__class__, "forward", new_forward)
+
+     @torch.no_grad()
+     def forward(self, x):
+
+         out = self.model.vision_model(x)
+
+         attn_outputs, mlp_outputs, block_outputs = [], [], []
+         for i, blk in enumerate(self.model.vision_model.encoder.layers):
+             attn_outputs.append(blk.attn_output)
+             mlp_outputs.append(blk.mlp_output)
+             block_outputs.append(blk.block_output)
+
+         attn_outputs = torch.stack(attn_outputs)
+         mlp_outputs = torch.stack(mlp_outputs)
+         block_outputs = torch.stack(block_outputs)
+         return attn_outputs, mlp_outputs, block_outputs
+
+
+ def image_clip_feature(
+     images, resolution=(224, 224), node_type="block", layer=-1
+ ):
+     if isinstance(images, list):
+         assert isinstance(images[0], Image.Image), "Input must be a list of PIL images."
+     else:
+         assert isinstance(images, Image.Image), "Input must be a PIL image."
+         images = [images]
+
+     transform = transforms.Compose(
+         [
+             transforms.Resize(resolution),
+             transforms.ToTensor(),
+             transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
+         ]
+     )
+
+     feat_extractor = CLIP()
+
+     outputs = []
+     for i, image in enumerate(images):
+         torch_image = transform(image)
+         attn_output, mlp_output, block_output = feat_extractor(
+             # torch_image.unsqueeze(0).cuda()
+             torch_image.unsqueeze(0)
+         )
+         out_dict = {
+             "attn": attn_output,
+             "mlp": mlp_output,
+             "block": block_output,
+         }
+         out = out_dict[node_type]
+         out = out[layer]
+         outputs.append(out.cpu())
+     outputs = torch.cat(outputs, dim=0)
+     return outputs
+
+
+
+ def extract_features(images, model_name="SAM(sam_vit_b)", node_type="block", layer=-1):
+     if model_name == "SAM(sam_vit_b)":
+         return image_sam_feature(images, node_type=node_type, layer=layer)
+     elif model_name == "DiNO(dinov2_vitb14_reg)":
+         return image_dino_feature(images, node_type=node_type, layer=layer)
+     elif model_name == "CLIP(openai/clip-vit-base-patch16)":
+         return image_clip_feature(images, node_type=node_type, layer=layer)
+     else:
+         raise ValueError(f"Model {model_name} not supported.")
+
+
+ def compute_ncut(
+     features,
+     num_eig=100,
+     num_sample_ncut=10000,
+     affinity_focal_gamma=0.3,
+     knn_ncut=10,
+     knn_tsne=10,
+     num_sample_tsne=1000,
+     perplexity=500,
+ ):
+     from ncut_pytorch import NCUT, rgb_from_tsne_3d
+
+     eigvecs, eigvals = NCUT(
+         num_eig=num_eig,
+         num_sample=num_sample_ncut,
+         # device="cuda:0",
+         affinity_focal_gamma=affinity_focal_gamma,
+         knn=knn_ncut,
+     ).fit_transform(features.reshape(-1, features.shape[-1]))
+     # map eigenvectors to RGB colors via 3D t-SNE
+     X_3d, rgb = rgb_from_tsne_3d(
+         eigvecs,
+         num_sample=num_sample_tsne,
+         perplexity=perplexity,
+         knn=knn_tsne,
+     )
+     rgb = rgb.reshape(features.shape[:3] + (3,))
+     return rgb
+
+
+ def dont_use_too_much_green(image_rgb):
+     # reorder color channels so that the central (foveal) 40% of the image is red-leading
+     x1, x2 = int(image_rgb.shape[1] * 0.3), int(image_rgb.shape[1] * 0.7)
+     y1, y2 = int(image_rgb.shape[2] * 0.3), int(image_rgb.shape[2] * 0.7)
+     sum_values = image_rgb[:, x1:x2, y1:y2].mean((0, 1, 2))
+     sorted_indices = sum_values.argsort(descending=True)
+     image_rgb = image_rgb[:, :, :, sorted_indices]
+     return image_rgb
+
+
+ def to_pil_images(images):
+     return [
+         Image.fromarray((image * 255).cpu().numpy().astype(np.uint8)).resize((256, 256), Image.NEAREST)
+         for image in images
+     ]
+
+
+ def main_fn(
+     images,
+     model_name="SAM(sam_vit_b)",
+     node_type="block",
+     layer=-1,
+     num_eig=100,
+     affinity_focal_gamma=0.3,
+     num_sample_ncut=10000,
+     knn_ncut=10,
+     num_sample_tsne=1000,
+     knn_tsne=10,
+     perplexity=500,
+ ):
+     if perplexity >= num_sample_tsne:
+         # raise gr.Error("Perplexity must be less than the number of samples for t-SNE.")
+         gr.Warning("Perplexity must be less than the number of samples for t-SNE.\n" f"Setting perplexity to {num_sample_tsne-1}.")
+         perplexity = num_sample_tsne - 1
+
+     images = [image[0] for image in images]
+     features = extract_features(
+         images, model_name=model_name, node_type=node_type, layer=layer
+     )
+     rgb = compute_ncut(
+         features,
+         num_eig=num_eig,
+         num_sample_ncut=num_sample_ncut,
+         affinity_focal_gamma=affinity_focal_gamma,
+         knn_ncut=knn_ncut,
+         knn_tsne=knn_tsne,
+         num_sample_tsne=num_sample_tsne,
+         perplexity=perplexity,
+     )
+     rgb = dont_use_too_much_green(rgb)
+     return to_pil_images(rgb)
+
+
+ default_images = ['/workspace/output/gradio/image_0.jpg', '/workspace/output/gradio/image_1.jpg', '/workspace/output/gradio/image_2.jpg', '/workspace/output/gradio/image_3.jpg', '/workspace/output/gradio/image_4.jpg', '/workspace/output/gradio/image_5.jpg']
+ default_outputs = ['/workspace/output/gradio/ncut_0.jpg', '/workspace/output/gradio/ncut_1.jpg', '/workspace/output/gradio/ncut_2.jpg', '/workspace/output/gradio/ncut_3.jpg', '/workspace/output/gradio/ncut_4.jpg', '/workspace/output/gradio/ncut_5.jpg']
+
+ demo = gr.Interface(
+     main_fn,
+     [
+         gr.Gallery(value=default_images, label="Select images", show_label=False, elem_id="images", columns=[3], rows=[1], object_fit="contain", height="auto", type="pil"),
+         gr.Dropdown(["SAM(sam_vit_b)", "DiNO(dinov2_vitb14_reg)", "CLIP(openai/clip-vit-base-patch16)"], label="Model", value="SAM(sam_vit_b)", elem_id="model_name"),
+         gr.Dropdown(["attn", "mlp", "block"], label="Node type", value="block", elem_id="node_type", info="attn: attention output, mlp: mlp output, block: sum of residual stream"),
+         gr.Slider(0, 11, step=1, label="Layer", value=11, elem_id="layer", info="which layer of the image backbone features"),
+         gr.Slider(1, 1000, step=1, label="Number of eigenvectors", value=100, elem_id="num_eig", info='increase for more object parts, decrease for whole object'),
+         gr.Slider(0.01, 1, step=0.01, label="Affinity focal gamma", value=0.3, elem_id="affinity_focal_gamma", info="decrease for more aggressive cleaning on the affinity matrix"),
+     ],
+     gr.Gallery(value=default_outputs, label="NCUT Embedding", show_label=False, elem_id="ncut", columns=[3], rows=[1], object_fit="contain", height="auto"),
+     additional_inputs=[
+         gr.Slider(100, 30000, step=100, label="num_sample (NCUT)", value=10000, elem_id="num_sample_ncut", info="for Nyström approximation"),
+         gr.Slider(1, 100, step=1, label="KNN (NCUT)", value=10, elem_id="knn_ncut", info="for Nyström approximation"),
+         gr.Slider(100, 10000, step=100, label="num_sample (t-SNE)", value=1000, elem_id="num_sample_tsne", info="for Nyström approximation; increasing will slow down t-SNE quite a lot"),
+         gr.Slider(1, 100, step=1, label="KNN (t-SNE)", value=10, elem_id="knn_tsne", info="for Nyström approximation"),
+         gr.Slider(10, 1000, step=10, label="Perplexity (t-SNE)", value=500, elem_id="perplexity", info="for t-SNE"),
+     ]
+ )
+
+ demo.launch(share=True)
+
+ # %%
+
+
+ # # %%
+ # from ncut_pytorch import NCUT, rgb_from_tsne_3d
+
+ # i_layer = -1
+ # inp = block_outputs[i_layer]
+ # eigvecs, eigvals = NCUT(
+ #     num_eig=1000, num_sample=10000, device="cuda:0", affinity_focal_gamma=0.3, knn=10
+ # ).fit_transform(inp.reshape(-1, inp.shape[-1]))
+ # print(eigvecs.shape, eigvals.shape)
+ # # %%
+ # X_3d, rgb = rgb_from_tsne_3d(
+ #     eigvecs[:, :100], num_sample=1000, perplexity=500, knn=10, seed=42
+ # )
+ # # %%
+ # image_rgb = rgb.reshape(*inp.shape[:-1], 3)
+ # # make sure the central (foveal) 20% of the image is red-leading
+ # x1, x2 = int(image_rgb.shape[1] * 0.4), int(image_rgb.shape[1] * 0.6)
+ # y1, y2 = int(image_rgb.shape[2] * 0.4), int(image_rgb.shape[2] * 0.6)
+ # sum_values = image_rgb[:, x1:x2, y1:y2].mean((0, 1, 2))
+ # sorted_indices = sum_values.argsort(descending=True)
+ # image_rgb = image_rgb[:, :, :, sorted_indices]
+
+ # import matplotlib.pyplot as plt
+
+ # fig, axes = plt.subplots(2, 3, figsize=(15, 10))
+ # for i, ax in enumerate(axes.flat):
+ #     ax.imshow(image_rgb[i])
+ #     ax.axis("off")
+
+ # %%
+ # offline helper cells: prepare the default gallery images and outputs
+ save_dir = "/workspace/output/gradio"
+ import os
+
+ os.makedirs(save_dir, exist_ok=True)
+
+ images = ['/workspace/guitars/lespual1.png', '/workspace/guitars/lespual2.png', '/workspace/guitars/lespual3.png', '/workspace/guitars/lespual4.png', '/workspace/guitars/lespual5.png', '/workspace/guitars/acoustic1.png']
+ images = [Image.open(image).convert("RGB") for image in images]
+ for i, image in enumerate(images):
+     image = image.resize((512, 512))
+     image.save(os.path.join(save_dir, f"image_{i}.jpg"), "JPEG", quality=70)
+ # %%
+ images = [(image, '') for image in images]
+ image_rbg = main_fn(images)  # PIL images returned by the gradio function
+ # %%
+ for i, rgb in enumerate(image_rbg):
+     rgb = rgb.resize((512, 512), Image.NEAREST)
+     rgb.save(os.path.join(save_dir, f"ncut_{i}.jpg"), "JPEG", quality=70)
+ # %%
+ # note: this cell expects the RGB tensor `image_rgb` from the commented-out cells above,
+ # not the PIL images in `image_rbg`
+ for i, rgb in enumerate(image_rgb):
+     rgb = Image.fromarray((rgb * 255).cpu().numpy().astype(np.uint8))
+     rgb.save(os.path.join(save_dir, f"ncut_{i}.png"))
+ # %%
images/image_0.jpg ADDED
images/image_1.jpg ADDED
images/image_2.jpg ADDED
images/image_3.jpg ADDED
images/image_4.jpg ADDED
images/image_5.jpg ADDED
images/ncut_0.jpg ADDED
images/ncut_1.jpg ADDED
images/ncut_2.jpg ADDED
images/ncut_3.jpg ADDED
images/ncut_4.jpg ADDED
images/ncut_5.jpg ADDED
requirements.txt ADDED
@@ -0,0 +1,3 @@
+ ncut-pytorch
+ transformers
+ segment-anything @ git+https://github.com/facebookresearch/segment-anything.git