huzey committed
Commit caaf478
1 Parent(s): b26090c

update gpu

Files changed (1):
  1. app.py +44 -67
app.py CHANGED
@@ -3,7 +3,6 @@ from einops import rearrange
 import torch
 import torch.nn.functional as F
 from PIL import Image
-import torchvision.transforms as transforms
 from torch import nn
 import numpy as np
 import os
@@ -17,6 +16,15 @@ USE_CUDA = torch.cuda.is_available()
 
 print("CUDA is available:", USE_CUDA)
 
+def transform_images(images, resolution=(1024, 1024)):
+    images = [image.convert("RGB").resize(resolution) for image in images]
+    # Convert to torch tensor
+    images = [torch.tensor(np.array(image).transpose(2, 0, 1)).float() / 255 for image in images]
+    # Normalize
+    images = [(image - 0.5) / 0.5 for image in images]
+    images = torch.stack(images)
+    return images
+
 class MobileSAM(nn.Module):
     def __init__(self, **kwargs):
         super().__init__(**kwargs)
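The new transform_images helper replaces the removed torchvision pipeline with plain PIL/NumPy preprocessing, normalizing to roughly [-1, 1] via (x - 0.5) / 0.5 rather than the ImageNet mean/std used by the old transforms.Normalize call. A minimal usage sketch (the file names are placeholders, not from the repo):

from PIL import Image

pil_images = [Image.open("example_0.png"), Image.open("example_1.png")]  # placeholder inputs
batch = transform_images(pil_images, resolution=(1024, 1024))
print(batch.shape)  # torch.Size([2, 3, 1024, 1024]), float32, values in [-1, 1]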
@@ -139,19 +147,12 @@ mobilesam = MobileSAM()
 
 def image_mobilesam_feature(
     images,
-    resolution=(1024, 1024),
     node_type="block",
     layer=-1,
 ):
 
-    transform = transforms.Compose(
-        [
-            transforms.Resize(resolution),
-            transforms.ToTensor(),
-            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
-        ]
-    )
-
+    if USE_CUDA:
+        images = images.cuda()
 
     feat_extractor = mobilesam
     if USE_CUDA:
@@ -159,12 +160,9 @@ def image_mobilesam_feature(
 
     # attn_outputs, mlp_outputs, block_outputs = [], [], []
     outputs = []
-    for i, image in enumerate(images):
-        torch_image = transform(image)
-        if USE_CUDA:
-            torch_image = torch_image.cuda()
+    for i in range(images.shape[0]):
         attn_output, mlp_output, block_output = feat_extractor(
-            torch_image.unsqueeze(0)
+            images[i].unsqueeze(0)
         )
         out_dict = {
             "attn": attn_output,
@@ -251,18 +249,12 @@ sam = SAM()
 
 def image_sam_feature(
     images,
-    resolution=(1024, 1024),
     node_type="block",
     layer=-1,
 ):
 
-    transform = transforms.Compose(
-        [
-            transforms.Resize(resolution),
-            transforms.ToTensor(),
-            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
-        ]
-    )
+    if USE_CUDA:
+        images = images.cuda()
 
     feat_extractor = sam
     if USE_CUDA:
@@ -270,12 +262,9 @@ def image_sam_feature(
 
     # attn_outputs, mlp_outputs, block_outputs = [], [], []
     outputs = []
-    for i, image in enumerate(images):
-        torch_image = transform(image)
-        if USE_CUDA:
-            torch_image = torch_image.cuda()
+    for i in range(images.shape[0]):
         attn_output, mlp_output, block_output = feat_extractor(
-            torch_image.unsqueeze(0)
+            images[i].unsqueeze(0)
         )
         out_dict = {
             "attn": attn_output,
@@ -338,27 +327,20 @@ class DiNOv2(torch.nn.Module):
 
 dinov2 = DiNOv2()
 
-def image_dino_feature(images, resolution=(448, 448), node_type="block", layer=-1):
+def image_dino_feature(images, node_type="block", layer=-1):
 
-    transform = transforms.Compose(
-        [
-            transforms.Resize(resolution),
-            transforms.ToTensor(),
-            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
-        ]
-    )
+    if USE_CUDA:
+        images = images.cuda()
 
     feat_extractor = dinov2
     if USE_CUDA:
         feat_extractor = feat_extractor.cuda()
 
+    # attn_outputs, mlp_outputs, block_outputs = [], [], []
     outputs = []
-    for i, image in enumerate(images):
-        torch_image = transform(image)
-        if USE_CUDA:
-            torch_image = torch_image.cuda()
+    for i in range(images.shape[0]):
         attn_output, mlp_output, block_output = feat_extractor(
-            torch_image.unsqueeze(0)
+            images[i].unsqueeze(0)
         )
         out_dict = {
             "attn": attn_output,
@@ -443,33 +425,20 @@ class CLIP(torch.nn.Module):
 clip = CLIP()
 
 def image_clip_feature(
-    images, resolution=(224, 224), node_type="block", layer=-1
+    images, node_type="block", layer=-1
 ):
-    if isinstance(images, list):
-        assert isinstance(images[0], Image.Image), "Input must be a list of PIL images."
-    else:
-        assert isinstance(images, Image.Image), "Input must be a PIL image."
-        images = [images]
-
-    transform = transforms.Compose(
-        [
-            transforms.Resize(resolution),
-            transforms.ToTensor(),
-            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
-        ]
-    )
+    if USE_CUDA:
+        images = images.cuda()
 
     feat_extractor = clip
     if USE_CUDA:
         feat_extractor = feat_extractor.cuda()
 
+    # attn_outputs, mlp_outputs, block_outputs = [], [], []
     outputs = []
-    for i, image in enumerate(images):
-        torch_image = transform(image)
-        if USE_CUDA:
-            torch_image = torch_image.cuda()
+    for i in range(images.shape[0]):
         attn_output, mlp_output, block_output = feat_extractor(
-            torch_image.unsqueeze(0)
+            images[i].unsqueeze(0)
        )
         out_dict = {
             "attn": attn_output,
@@ -527,27 +496,35 @@ def compute_hash(*args, **kwargs):
 
 
 @spaces.GPU(duration=30)
-def run_model_on_image(image, model_name="sam", node_type="block", layer=-1):
+def run_model_on_image(images, model_name="sam", node_type="block", layer=-1):
     global USE_CUDA
     USE_CUDA = True
 
     if model_name == "SAM(sam_vit_b)":
         if not USE_CUDA:
             gr.warning("GPU not detected. Running SAM on CPU, ~30s/image.")
-        result = image_sam_feature([image], node_type=node_type, layer=layer)
+        result = image_sam_feature(images, node_type=node_type, layer=layer)
     elif model_name == 'MobileSAM':
-        result = image_mobilesam_feature([image], node_type=node_type, layer=layer)
+        result = image_mobilesam_feature(images, node_type=node_type, layer=layer)
     elif model_name == "DiNO(dinov2_vitb14_reg)":
-        result = image_dino_feature([image], node_type=node_type, layer=layer)
+        result = image_dino_feature(images, node_type=node_type, layer=layer)
     elif model_name == "CLIP(openai/clip-vit-base-patch16)":
-        result = image_clip_feature([image], node_type=node_type, layer=layer)
+        result = image_clip_feature(images, node_type=node_type, layer=layer)
     else:
         raise ValueError(f"Model {model_name} not supported.")
 
     USE_CUDA = False
     return result
 
-def extract_features(images, model_name="sam", node_type="block", layer=-1):
+def extract_features(images, model_name="mobilesam", node_type="block", layer=-1):
+    resolution_dict = {
+        "mobilesam": (1024, 1024),
+        "sam(sam_vit_b)": (1024, 1024),
+        "dinov2(dinov2_vitb14_reg)": (448, 448),
+        "clip(openai/clip-vit-base-patch16)": (224, 224),
+    }
+    images = transform_images(images, resolution=resolution_dict[model_name])
+
     # Compute the cache key
     cache_key = compute_hash(images, model_name, node_type, layer)
 
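run_model_on_image is wrapped in @spaces.GPU(duration=30), so on a ZeroGPU Space the CUDA device is only guaranteed inside that call; the function flips the global USE_CUDA flag on entry and back off before returning, which is what makes the per-model feature functions move their tensors to the GPU. A minimal sketch of that decorator pattern, assuming the huggingface spaces package on a ZeroGPU Space (gpu_step is illustrative, not from the repo):

import spaces
import torch

@spaces.GPU(duration=30)              # GPU is attached only for the duration of this call
def gpu_step(batch: torch.Tensor):
    return (batch.cuda() ** 2).sum().item()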
@@ -556,7 +533,7 @@ def extract_features(images, model_name="sam", node_type="block", layer=-1):
         print("Cache hit!")
         return cache[cache_key]
 
-    result = run_model_on_image(images[0], model_name=model_name, node_type=node_type, layer=layer)
+    result = run_model_on_image(images, model_name=model_name, node_type=node_type, layer=layer)
 
     # Store the result in the cache
     cache[cache_key] = result
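extract_features now preprocesses the whole batch once, picks the resize resolution per model, and hashes the resulting tensor so repeated requests skip the GPU-decorated run_model_on_image call. Note that the resolution_dict keys ("mobilesam", "sam(sam_vit_b)", "dinov2(dinov2_vitb14_reg)", "clip(openai/clip-vit-base-patch16)") are not the same strings as the model_name values matched in run_model_on_image, yet the same model_name is used for both lookups; the mapping below is one illustrative way to reconcile the two, not part of the commit:

# Illustrative only: bridge the UI-facing model names tested in run_model_on_image
# to the lowercase keys expected by resolution_dict.
UI_NAME_TO_RESOLUTION_KEY = {
    "MobileSAM": "mobilesam",
    "SAM(sam_vit_b)": "sam(sam_vit_b)",
    "DiNO(dinov2_vitb14_reg)": "dinov2(dinov2_vitb14_reg)",
    "CLIP(openai/clip-vit-base-patch16)": "clip(openai/clip-vit-base-patch16)",
}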
 