Tsumugii24 commited on
Commit
a59b721
1 Parent(s): 95be552

add model auto downloads

Browse files
Files changed (1) hide show
  1. app.py +13 -12
app.py CHANGED
@@ -27,11 +27,11 @@ data_url_dict = {
27
  }
28
 
29
  model_url_dict = {
30
- "cnn_se.pt": "https://huggingface.co/Tsumugii/lesion-cells-det/raw/main/cnn_se.pt",
31
- "detr_based.pt": "https://huggingface.co/Tsumugii/lesion-cells-det/raw/main/detr_based.pt",
32
- "vit_based.pt": "https://huggingface.co/Tsumugii/lesion-cells-det/raw/main/vit_based.pt",
33
- "yolov5_based.pt": "https://huggingface.co/Tsumugii/lesion-cells-det/raw/main/yolov5_based.pt",
34
- "yolov8_based.pt": "https://huggingface.co/Tsumugii/lesion-cells-det/raw/main/yolov8_based.pt",
35
  }
36
 
37
  # 判断字体文件是否存在
@@ -77,7 +77,7 @@ def download_fonts(font_diff):
77
  for k, v in data_url_dict.items():
78
  if k in font_diff:
79
  font_name = v.split("/")[-1] # 字体名称
80
- fonts_directory_path.mkdir(parents=True, exist_ok=True) # 创建目录
81
 
82
  font_file_path = f"{ROOT_PATH}/fonts/{font_name}" # 字体路径
83
  # 下载字体文件
@@ -87,13 +87,14 @@ def download_fonts(font_diff):
87
  def download_models(model_diff):
88
  global model_name
89
 
90
- for k in model_diff:
91
- v = model_url_dict[k]
92
- model_name = v.split("/")[-1] # 模型名称
 
93
 
94
- model_file_path = f"{ROOT_PATH}/models/{model_name}" # 模型路径
95
- # 下载模型文件
96
- wget.download(v, model_file_path)
97
 
98
 
99
  is_fonts(fonts_directory_path)
 
27
  }
28
 
29
  model_url_dict = {
30
+ "cnn_se.pt": "https://huggingface.co/Tsumugii/lesion-cells-det/resolve/main/cnn_se.pt",
31
+ "detr_based.pt": "https://huggingface.co/Tsumugii/lesion-cells-det/resolve/main/detr_based.pt",
32
+ "vit_based.pt": "https://huggingface.co/Tsumugii/lesion-cells-det/resolve/main/vit_based.pt",
33
+ "yolov5_based.pt": "https://huggingface.co/Tsumugii/lesion-cells-det/resolve/main/yolov5_based.pt",
34
+ "yolov8_based.pt": "https://huggingface.co/Tsumugii/lesion-cells-det/resolve/main/yolov8_based.pt",
35
  }
36
 
37
  # 判断字体文件是否存在
 
77
  for k, v in data_url_dict.items():
78
  if k in font_diff:
79
  font_name = v.split("/")[-1] # 字体名称
80
+ fonts_directory_path.mkdir(parents=True, exist_ok=True) # 创建本地字体目录
81
 
82
  font_file_path = f"{ROOT_PATH}/fonts/{font_name}" # 字体路径
83
  # 下载字体文件
 
87
  def download_models(model_diff):
88
  global model_name
89
 
90
+ for k, v in model_url_dict.items():
91
+ if k in model_diff:
92
+ model_name = v.split("/")[-1] # 模型名称
93
+ models_directory_path.mkdir(parents=True, exist_ok=True) # 创建本地模型目录
94
 
95
+ model_file_path = f"{ROOT_PATH}/models/{model_name}" # 模型路径
96
+ # 下载模型文件
97
+ wget.download(v, model_file_path)
98
 
99
 
100
  is_fonts(fonts_directory_path)