huzey committed on
Commit 3d40e53
Parent: e2b7cb4

add progress bar

Files changed (2)
  1. app.py +108 -30
  2. app_text.py +9 -1
app.py CHANGED

@@ -77,7 +77,9 @@ def compute_ncut(
     min_dist=0.1,
     sampling_method="fps",
     metric="cosine",
+    progess_start=0.4,
 ):
+    progress = gr.Progress()
     logging_str = ""
 
     num_nodes = np.prod(features.shape[:-1])
@@ -88,6 +90,7 @@ compute_ncut(
         logging_str += f"Number of eigenvectors should be less than half the number of nodes.\n" f"Setting num_eig to {num_nodes // 2 - 1}.\n"
 
     start = time.time()
+    progress(progess_start+0.0, desc="NCut")
     eigvecs, eigvals = NCUT(
         num_eig=num_eig,
         num_sample=num_sample_ncut,
@@ -102,6 +105,7 @@ compute_ncut(
     logging_str += f"NCUT time: {time.time() - start:.2f}s\n"
 
     start = time.time()
+    progress(progess_start+0.01, desc="spectral-tSNE")
     _, rgb = eigenvector_to_rgb(
         eigvecs,
         method=embedding_method,
@@ -249,15 +253,34 @@ def blend_image_with_heatmap(image, heatmap, opacity1=0.5, opacity2=0.5):
     blended = (1 - opacity1) * image + opacity2 * heatmap
     return blended.astype(np.uint8)
 
-def make_cluster_plot(eigvecs, images, h=64, w=64):
+def make_cluster_plot(eigvecs, images, h=64, w=64, progess_start=0.6):
+    progress = gr.Progress()
+    progress(progess_start, desc="Finding Clusters by FPS")
+    device = 'cuda' if torch.cuda.is_available() else 'cpu'
+    eigvecs = eigvecs.to(device)
     from ncut_pytorch.ncut_pytorch import farthest_point_sampling
     magnitude = torch.norm(eigvecs, dim=-1)
-    p = 0.5
+    p = 0.8
     top_p_idx = magnitude.argsort(descending=True)[:int(p * magnitude.shape[0])]
-    num_samples = 50
+    num_samples = 300
+    if num_samples > top_p_idx.shape[0]:
+        num_samples = top_p_idx.shape[0]
     fps_idx = farthest_point_sampling(eigvecs[top_p_idx], num_samples)
     fps_idx = top_p_idx[fps_idx]
 
+    # fps round 2 on the heatmap
+    left = eigvecs[fps_idx, :].clone()
+    right = eigvecs.clone()
+    left = F.normalize(left, dim=-1)
+    right = F.normalize(right, dim=-1)
+    heatmap = left @ right.T
+    heatmap = F.normalize(heatmap, dim=-1)
+    num_samples = 80
+    if num_samples > fps_idx.shape[0]:
+        num_samples = fps_idx.shape[0]
+    r2_fps_idx = farthest_point_sampling(heatmap, num_samples)
+    fps_idx = fps_idx[r2_fps_idx]
+
     # downsample to 256x256
     images = F.interpolate(images, (256, 256), mode="bilinear")
     images = images.cpu().numpy()
@@ -269,29 +292,57 @@ def make_cluster_plot(eigvecs, images, h=64, w=64):
     # sort the fps_idx by the mean of the heatmap
     fps_heatmaps = {}
     sort_values = []
+    top3_image_idx = {}
     for _, idx in enumerate(fps_idx):
-        device = 'cuda' if torch.cuda.is_available() else 'cpu'
-        eigvecs = eigvecs.to(device)
         heatmap = F.cosine_similarity(eigvecs, eigvecs[idx][None], dim=-1)
+
+        # def top_percentile(tensor, p=0.8, max_size=10000):
+        #     tensor = tensor.clone().flatten()
+        #     if tensor.shape[0] > max_size:
+        #         tensor = tensor[torch.randperm(tensor.shape[0])[:max_size]]
+        #     return tensor.quantile(p)
+        # top_p = top_percentile(heatmap, p=0.5)
+        top_p = 0.5
+
         heatmap = heatmap.reshape(-1, h, w)
-        mask = (heatmap > 0.5).float()
+        mask = (heatmap > top_p).float()
+        # take top 3 masks only
+        mask_sort_values = mask.mean((1, 2))
+        mask_sort_idx = torch.argsort(mask_sort_values, descending=True)
+        mask = mask[mask_sort_idx[:3]]
         sort_values.append(mask.mean().item())
-        fps_heatmaps[idx.item()] = heatmap.cpu()
+        # fps_heatmaps[idx.item()] = heatmap.cpu()
+        fps_heatmaps[idx.item()] = heatmap[mask_sort_idx[:3]].cpu()
+        top3_image_idx[idx.item()] = mask_sort_idx[:3]
+    # do the sorting
+    _sort_idx = torch.tensor(sort_values).argsort(descending=True)
+    fps_idx = fps_idx[_sort_idx]
+    # reverse the fps_idx
+    # fps_idx = fps_idx.flip(0)
+    # discard the big clusters
+    fps_idx = fps_idx[10:]
+    # shuffle the fps_idx
+    fps_idx = fps_idx[torch.randperm(fps_idx.shape[0])]
 
     fig_images = []
     i_cluster = 0
-    for i_fig in range(10):
+    num_plots = 10
+    plot_step_float = (1.0 - progess_start) / num_plots
+    for i_fig in range(num_plots):
+        progress(progess_start + i_fig * plot_step_float, desc="Plotting Clusters")
         fig, axs = plt.subplots(3, 5, figsize=(15, 9))
         for ax in axs.flatten():
             ax.axis("off")
         for j, idx in enumerate(fps_idx[i_fig*5:i_fig*5+5]):
             heatmap = fps_heatmaps[idx.item()]
-            mask = (heatmap > 0.1).float()
-            sorted_image_idxs = torch.argsort(mask.mean((1, 2)), descending=True)
+            # mask = (heatmap > 0.1).float()
+            # sorted_image_idxs = torch.argsort(mask.mean((1, 2)), descending=True)
             size = (images.shape[1], images.shape[2])
             heatmap = apply_reds_colormap(heatmap, size)
-            for i, image_idx in enumerate(sorted_image_idxs[:3]):
-                _heatmap = blend_image_with_heatmap(images[image_idx], heatmap[image_idx])
+            # for i, image_idx in enumerate(sorted_image_idxs[:3]):
+            for i, image_idx in enumerate(top3_image_idx[idx.item()]):
+                # _heatmap = blend_image_with_heatmap(images[image_idx], heatmap[image_idx])
+                _heatmap = blend_image_with_heatmap(images[image_idx], heatmap[i])
                 axs[i, j].imshow(_heatmap)
                 if i == 0:
                     axs[i, j].set_title(f"cluster {i_cluster+1}", fontsize=24)
@@ -348,6 +399,9 @@ def ncut_run(
     lisa_prompt2="",
     lisa_prompt3="",
 ):
+    progress = gr.Progress()
+    progress(0.2, desc="Feature Extraction")
+
     logging_str = ""
     if "AlignedThreeModelAttnNodes" == model_name:
         # dirty patch for the alignedcut paper
@@ -396,12 +450,16 @@ def ncut_run(
     # print(f"Feature extraction time (gpu): {time.time() - start:.2f}s")
     logging_str += f"Backbone time: {time.time() - start:.2f}s\n"
 
+    progress(0.4, desc="NCut")
+
     if recursion:
         rgbs = []
         recursion_gammas = [recursion_l1_gamma, recursion_l2_gamma, recursion_l3_gamma]
         inp = features
+        progress_start = 0.4
         for i, n_eigs in enumerate([num_eig, recursion_l2_n_eigs, recursion_l3_n_eigs]):
             logging_str += f"Recursion #{i+1}\n"
+            progress_start += + 0.1 * i
             rgb, _logging_str, eigvecs = compute_ncut(
                 inp,
                 num_eig=n_eigs,
@@ -417,6 +475,7 @@ def ncut_run(
                 min_dist=min_dist,
                 sampling_method=sampling_method,
                 metric="cosine" if i == 0 else recursion_metric,
+                progess_start=progress_start,
             )
             logging_str += _logging_str
 
@@ -424,6 +483,7 @@ def ncut_run(
         if "AlignedThreeModelAttnNodes" == model_name:
             # dirty patch for the alignedcut paper
             start = time.time()
+            progress(progress_start + 0.09, desc=f"Plotting Recursion {i+1}")
             pil_images = []
             for i_image in range(rgb.shape[0]):
                 _im = plot_one_image_36_grid(images[i_image], rgb[i_image])
@@ -442,6 +502,8 @@ def ncut_run(
     if old_school_ncut: # individual images
         logging_str += "Running NCut for each image independently\n"
         rgb = []
+        progress_start = 0.4
+        step_float = 0.6 / features.shape[0]
        for i_image in range(features.shape[0]):
            logging_str += f"Image #{i_image+1}\n"
            feature = features[i_image]
@@ -459,6 +521,7 @@ def ncut_run(
                 n_neighbors=n_neighbors,
                 min_dist=min_dist,
                 sampling_method=sampling_method,
+                progess_start=progress_start+step_float*i_image,
             )
             logging_str += _logging_str
             rgb.append(_rgb[0])
@@ -486,6 +549,7 @@ def ncut_run(
     if "AlignedThreeModelAttnNodes" == model_name:
         # dirty patch for the alignedcut paper
         start = time.time()
+        progress(0.6, desc="Plotting")
         pil_images = []
         for i_image in range(rgb.shape[0]):
             _im = plot_one_image_36_grid(images[i_image], rgb[i_image])
@@ -506,15 +570,18 @@ def ncut_run(
 
     if not video_output:
         start = time.time()
+        progress_start = 0.6
+        progress(progress_start, desc="Plotting Clusters")
         h, w = features.shape[1], features.shape[2]
         if torch.cuda.is_available():
             images = images.cuda()
         _images = reverse_transform_image(images, stablediffusion="stable" in model_name.lower())
-        cluster_images = make_cluster_plot(eigvecs, _images, h=h, w=w)
+        cluster_images = make_cluster_plot(eigvecs, _images, h=h, w=w, progess_start=progress_start)
         logging_str += f"plot time: {time.time() - start:.2f}s\n"
 
 
     if video_output:
+        progress(0.8, desc="Saving Video")
         video_path = get_random_path()
         video_cache.add_video(video_path)
         pil_images_to_video(to_pil_images(rgb), video_path)
@@ -526,26 +593,26 @@ def ncut_run(
 
 def _ncut_run(*args, **kwargs):
     n_ret = kwargs.pop("n_ret", 1)
-    # try:
-    #     if torch.cuda.is_available():
-    #         torch.cuda.empty_cache()
+    try:
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
 
-    #     ret = ncut_run(*args, **kwargs)
+        ret = ncut_run(*args, **kwargs)
 
-    #     if torch.cuda.is_available():
-    #         torch.cuda.empty_cache()
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
 
-    #     ret = list(ret)[:n_ret] + [ret[-1]]
-    #     return ret
-    # except Exception as e:
-    #     gr.Error(str(e))
-    #     if torch.cuda.is_available():
-    #         torch.cuda.empty_cache()
-    #     return *(None for _ in range(n_ret)), "Error: " + str(e)
-
-    ret = ncut_run(*args, **kwargs)
-    ret = list(ret)[:n_ret] + [ret[-1]]
-    return ret
+        ret = list(ret)[:n_ret] + [ret[-1]]
+        return ret
+    except Exception as e:
+        gr.Error(str(e))
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+        return *(None for _ in range(n_ret)), "Error: " + str(e)
+
+    # ret = ncut_run(*args, **kwargs)
+    # ret = list(ret)[:n_ret] + [ret[-1]]
+    # return ret
 
 if USE_HUGGINGFACE_ZEROGPU:
     @spaces.GPU(duration=20)
@@ -744,10 +811,15 @@ def run_fn(
     n_ret=1,
 ):
 
+    progress=gr.Progress()
+    progress(0, desc="Starting")
+
+
     if images is None:
         gr.Warning("No images selected.")
         return *(None for _ in range(n_ret)), "No images selected."
 
+    progress(0.05, desc="Processing Images")
     video_output = False
     if isinstance(images, str):
         images = extract_video_frames(images, max_frames=max_frames)
@@ -767,6 +839,7 @@ def run_fn(
     images = [transform_image(image, resolution=resolution, stablediffusion=stablediffusion) for image in images]
     images = torch.stack(images)
 
+    progress(0.1, desc="Downloading Model")
 
     if is_lisa:
         import subprocess
@@ -976,10 +1049,13 @@ def make_dataset_images_section(advanced=False, is_random=False):
     def load_dataset_images(is_advanced, dataset_name, num_images=10,
                             is_filter=True, filter_by_class_text="0,1,2",
                             is_random=False, seed=1):
+        progress = gr.Progress()
+        progress(0, desc="Loading Images")
        if is_advanced == "Basic":
            gr.Info("Loaded images from Ego-Exo4D")
            return default_images
        try:
+            progress(0.5, desc="Downloading Dataset")
            dataset = load_dataset(dataset_name, trust_remote_code=True)
            key = list(dataset.keys())[0]
            dataset = dataset[key]
@@ -990,6 +1066,7 @@ def make_dataset_images_section(advanced=False, is_random=False):
            num_images = len(dataset)
 
        if is_filter:
+            progress(0.8, desc="Filtering Images")
            classes = [int(i) for i in filter_by_class_text.split(",")]
            labels = np.array(dataset['label'])
            unique_labels = np.unique(labels)
@@ -1193,6 +1270,7 @@ with demo:
     with gr.Column(scale=5, min_width=200):
         input_gallery, submit_button, clear_images_button = make_input_images_section()
         dataset_dropdown, num_images_slider, random_seed_slider, load_images_button = make_dataset_images_section()
+        num_images_slider.value = 30
         logging_text = gr.Textbox("Logging information", label="Logging", elem_id="logging", type="text", placeholder="Logging information", autofocus=False, autoscroll=False)
 
     with gr.Column(scale=5, min_width=200):
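
Note on the make_cluster_plot change above: the commit replaces single-round FPS (p = 0.5, 50 samples) with a two-round scheme, first sampling 300 candidates in eigenvector space, then running FPS again on each candidate's normalized cosine-similarity heatmap and keeping 80, so the final clusters are spread out in heatmap space rather than eigenvector space. A minimal sketch of that idea follows; greedy_fps is a hypothetical stand-in for ncut_pytorch's farthest_point_sampling, and the toy shapes are assumptions for illustration, not the app's code:

import torch
import torch.nn.functional as F

def greedy_fps(x, k):
    # plain greedy farthest point sampling over rows of x; returns k indices
    n = x.shape[0]
    idx = torch.zeros(k, dtype=torch.long)      # idx[0] = 0 seeds the walk
    dist = torch.full((n,), float("inf"))
    for i in range(1, k):
        d = torch.cdist(x[idx[i - 1]][None], x)[0]
        dist = torch.minimum(dist, d)           # distance to nearest chosen point
        idx[i] = dist.argmax()                  # pick the farthest remaining point
    return idx

eigvecs = torch.randn(1000, 32)                 # toy stand-in for NCUT eigenvectors
magnitude = eigvecs.norm(dim=-1)
top_p_idx = magnitude.argsort(descending=True)[: int(0.8 * len(magnitude))]

# round 1: FPS in eigenvector space over the top-80%-magnitude nodes
fps_idx = top_p_idx[greedy_fps(eigvecs[top_p_idx], min(300, len(top_p_idx)))]

# round 2: FPS on row-normalized cosine-similarity heatmaps, so the kept
# candidates are diverse in heatmap space
heatmap = F.normalize(eigvecs[fps_idx], dim=-1) @ F.normalize(eigvecs, dim=-1).T
heatmap = F.normalize(heatmap, dim=-1)
fps_idx = fps_idx[greedy_fps(heatmap, min(80, len(fps_idx)))]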
app_text.py CHANGED

@@ -150,6 +150,7 @@ def ncut_run(
     min_dist=0.1,
     sampling_method="fps",
 ):
+    progress = gr.Progress()
     logging_str = ""
     if perplexity >= num_sample_tsne or n_neighbors >= num_sample_tsne:
         # raise gr.Error("Perplexity must be less than the number of samples for t-SNE.")
@@ -163,6 +164,7 @@ def ncut_run(
 
     node_type = node_type.split(":")[0].strip()
 
+    progress(0.5, desc="Feature Extraction")
     model = model.to("cuda" if torch.cuda.is_available() else "cpu")
 
     start = time.time()
@@ -180,6 +182,7 @@ def ncut_run(
     # print(f"Feature extraction time (gpu): {time.time() - start:.2f}s")
     logging_str += f"Backbone time: {time.time() - start:.2f}s\n"
 
+    progress(0.6, desc="NCUT & spectral-tSNE")
     rgb, _logging_str, _ = compute_ncut(
         features,
         num_eig=num_eig,
@@ -197,6 +200,7 @@ def ncut_run(
     logging_str += _logging_str
 
     start = time.time()
+    progress(0.8, desc="Plotting")
     title = f"{model_name}, Layer {layer}, {node_type}"
     fig = make_plot(token_texts, rgb, title=title)
     logging_str += f"Plotting time: {time.time() - start:.2f}s\n"
@@ -223,6 +227,8 @@ else:
         return _ncut_run(*args, **kwargs)
 
 def real_run(model_name, text, layer, node_type, num_eig, affinity_focal_gamma, num_sample_ncut, knn_ncut, embedding_method, num_sample_tsne, knn_tsne, perplexity, n_neighbors, min_dist, sampling_method):
+    progress = gr.Progress()
+    progress(0.1, desc="Downloading model")
     model = TEXT_MODEL_DICT[model_name]()
     return __ncut_run(model, text, model_name, layer, num_eig, node_type,
                       affinity_focal_gamma, num_sample_ncut, knn_ncut, embedding_method,
@@ -251,7 +257,9 @@ def make_demo():
         clear_button = gr.Button("🗑️Clear", elem_id='clear_button', variant='stop')
     with gr.Column(scale=5, min_width=200):
         gr.Markdown("### Parameters <a style='color: #0044CC;' href='https://ncut-pytorch.readthedocs.io/en/latest/how_to_get_better_segmentation/' target='_blank'>Help</a>")
-        model_name = gr.Dropdown(list(TEXT_MODEL_DICT.keys()), label="Model", value="meta-llama/Meta-Llama-3.1-8B")
+        model_list = list(TEXT_MODEL_DICT.keys())
+        model_list = [model for model in model_list if model != "meta-llama/Meta-Llama-3-8B"]
+        model_name = gr.Dropdown(model_list, label="Model", value="meta-llama/Meta-Llama-3.1-8B")
         layer = gr.Slider(1, 32, step=1, value=32, label="Layer")
         node_type = gr.Dropdown(["attn: attention output", "mlp: mlp output", "block: sum of residual"], label="Node Type", value="block: sum of residual")
         num_eig = gr.Slider(minimum=1, maximum=1000, step=1, value=100, label="Number of Eigenvectors")
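
Note on the progress wiring shared by both files: each stage reports a fraction via gr.Progress, and nested helpers receive a progess_start offset so their updates land inside the caller's slice of the bar. A minimal runnable sketch of the same pattern, assuming Gradio's documented default-argument form of gr.Progress (the commit instead constructs gr.Progress() inside each function body); slow_step and the 0.5/0.5 split are illustrative, not the app's code:

import time
import gradio as gr

def slow_step(start, span, progress):
    # report fractions inside [start, start + span)
    for i in range(5):
        progress(start + span * i / 5, desc=f"step {i + 1}/5")
        time.sleep(0.2)

def run(progress=gr.Progress()):
    progress(0.0, desc="Starting")
    slow_step(0.0, 0.5, progress)   # stands in for feature extraction
    slow_step(0.5, 0.5, progress)   # stands in for NCUT + plotting
    return "done"

demo = gr.Interface(fn=run, inputs=None, outputs="text")
if __name__ == "__main__":
    demo.launch()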