huzey commited on
Commit
112a402
1 Parent(s): 339b928

fix text lisa

Browse files
Files changed (1) hide show
  1. app_text.py +12 -0
app_text.py CHANGED
@@ -21,6 +21,10 @@ import matplotlib.pyplot as plt
21
  import matplotlib.colors as mcolors
22
  import numpy as np
23
 
 
 
 
 
24
  from ncut_pytorch import NCUT, eigenvector_to_rgb
25
 
26
  from ncut_pytorch.backbone_text import MODEL_DICT as TEXT_MODEL_DICT
@@ -227,6 +231,14 @@ else:
227
  return _ncut_run(*args, **kwargs)
228
 
229
  def real_run(model_name, text, layer, node_type, num_eig, affinity_focal_gamma, num_sample_ncut, knn_ncut, embedding_method, num_sample_tsne, knn_tsne, perplexity, n_neighbors, min_dist, sampling_method):
 
 
 
 
 
 
 
 
230
  progress = gr.Progress()
231
  progress(0.1, desc="Downloading model")
232
  model = TEXT_MODEL_DICT[model_name]()
 
21
  import matplotlib.colors as mcolors
22
  import numpy as np
23
 
24
+ import subprocess
25
+ import sys
26
+ import importlib
27
+
28
  from ncut_pytorch import NCUT, eigenvector_to_rgb
29
 
30
  from ncut_pytorch.backbone_text import MODEL_DICT as TEXT_MODEL_DICT
 
231
  return _ncut_run(*args, **kwargs)
232
 
233
  def real_run(model_name, text, layer, node_type, num_eig, affinity_focal_gamma, num_sample_ncut, knn_ncut, embedding_method, num_sample_tsne, knn_tsne, perplexity, n_neighbors, min_dist, sampling_method):
234
+
235
+ # remove the LISA model from the sys.path
236
+
237
+ if "/tmp/lisa_transformers_v433" in sys.path:
238
+ sys.path.remove("/tmp/lisa_transformers_v433")
239
+
240
+ transformers = importlib.import_module("transformers")
241
+
242
  progress = gr.Progress()
243
  progress(0.1, desc="Downloading model")
244
  model = TEXT_MODEL_DICT[model_name]()