"""Download the Sapiens Lite TorchScript checkpoints listed in ``config``.

Each model URL in ``SAPIENS_LITE_MODELS`` (a ``{task: {model_name: url}}``
mapping, as iterated in ``main``) is saved to
``checkpoints/<task>/<model_name>_torchscript.pt2``. Models whose file already
exists are skipped.
"""

import os
import json
import requests
from tqdm import tqdm
from config import SAPIENS_LITE_MODELS


def download_file(url, filename, chunk_size=1024, timeout=60):
    """Stream *url* to *filename* while showing a tqdm progress bar.

    The body is first written to ``<filename>.part`` and renamed to *filename*
    only after it has been fully received. As a result, a failed or
    interrupted download never leaves a truncated file at *filename*. Such a
    file would otherwise look like a finished checkpoint to ``main``, which
    skips any path that already exists.

    Args:
        url: HTTP(S) URL to fetch.
        filename: Destination path for the downloaded file.
        chunk_size: Number of bytes per chunk read from the response stream.
        timeout: Seconds passed to ``requests.get`` as the connect/read
            timeout, so a stalled server cannot hang the script forever.

    Raises:
        requests.HTTPError: If the server responds with a 4xx/5xx status.
        requests.RequestException: On other network failures.
    """
    tmp_filename = filename + ".part"
    # Using the response as a context manager releases the streamed connection.
    with requests.get(url, stream=True, timeout=timeout) as response:
        # Without this check, an HTTP error page (e.g. a 404) would be saved
        # as if it were the model file.
        response.raise_for_status()
        total_size = int(response.headers.get('content-length', 0))
        try:
            with open(tmp_filename, 'wb') as file, tqdm(
                desc=filename,
                total=total_size,
                unit='iB',
                unit_scale=True,
                unit_divisor=1024,
            ) as progress_bar:
                for data in response.iter_content(chunk_size=chunk_size):
                    size = file.write(data)
                    progress_bar.update(size)
        except BaseException:
            # Drop the partial file (also on Ctrl-C) so a rerun starts clean.
            if os.path.exists(tmp_filename):
                os.remove(tmp_filename)
            raise
    # Put the completed file in place; *filename* never exists half-written.
    os.replace(tmp_filename, filename)


def main():
    """Download every model in ``SAPIENS_LITE_MODELS`` that is not yet on disk."""
    # Model URLs come from the project's config module, grouped by task.
    model_urls = SAPIENS_LITE_MODELS

    for task, models in model_urls.items():
        checkpoints_dir = os.path.join('checkpoints', task)
        os.makedirs(checkpoints_dir, exist_ok=True)

        for model_name, url in models.items():
            model_filename = f"{model_name}_torchscript.pt2"
            model_path = os.path.join(checkpoints_dir, model_filename)

            if not os.path.exists(model_path):
                print(f"Downloading {task} {model_name} model...")
                download_file(url, model_path)
                print(f"{task} {model_name} model downloaded successfully.")
            else:
                print(f"{task} {model_name} model already exists. Skipping download.")


if __name__ == "__main__":
    main()