File size: 1,401 Bytes
94f04b7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import os
import json
import requests
from tqdm import tqdm
from config import SAPIENS_LITE_MODELS

def download_file(url, filename, timeout=(10, 60), chunk_size=1024):
    """Stream ``url`` to ``filename`` while showing a tqdm progress bar.

    The body is first written to ``filename + ".part"`` and only renamed onto
    ``filename`` once the whole response has been received. An interrupted or
    failed download therefore never leaves a truncated file at ``filename``.
    This matters because a caller that skips existing files would otherwise
    treat a truncated file as complete forever.

    Args:
        url: HTTP(S) URL to fetch.
        filename: Destination path; its parent directory must already exist.
        timeout: ``requests`` timeout, either seconds or a ``(connect, read)``
            tuple. The read timeout bounds the wait between received chunks,
            not the total download time.
        chunk_size: Number of bytes requested per ``iter_content`` chunk.

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status.
        requests.RequestException: On connection errors or timeouts.
        OSError: If the temporary file cannot be written or renamed.
    """
    tmp_path = f"{filename}.part"
    # Using the response as a context manager releases the streamed
    # connection even if writing fails midway.
    with requests.get(url, stream=True, timeout=timeout) as response:
        # Without this check an error page (e.g. a 404 HTML body) would be
        # saved to disk as if it were the requested file.
        response.raise_for_status()
        # Content-Length may be absent; None lets tqdm show an open-ended bar.
        total_size = int(response.headers.get('content-length', 0)) or None

        try:
            with open(tmp_path, 'wb') as file, tqdm(
                desc=filename,
                total=total_size,
                unit='iB',
                unit_scale=True,
                unit_divisor=1024,
            ) as progress_bar:
                for data in response.iter_content(chunk_size=chunk_size):
                    size = file.write(data)
                    progress_bar.update(size)
        except BaseException:
            # BaseException so Ctrl-C also cleans up. Remove the partial file
            # and re-raise so the caller still sees the failure.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
            raise

    # Rename within the same directory: after this, `filename` is either
    # absent or holds the complete download.
    os.replace(tmp_path, filename)

def main(models=None, root='checkpoints'):
    """Download every configured TorchScript checkpoint that is missing on disk.

    Each checkpoint is stored at ``<root>/<task>/<model_name>_torchscript.pt2``.
    Files that already exist are left untouched, so re-running only fetches
    what is missing.

    Args:
        models: Nested mapping ``{task: {model_name: url}}``. ``None`` (the
            default) uses ``SAPIENS_LITE_MODELS`` from the ``config`` module.
        root: Directory under which one subdirectory per task is created.
            Defaults to ``'checkpoints'``, relative to the working directory.
    """
    # The URL table comes from the imported config object, not from a JSON file.
    model_urls = SAPIENS_LITE_MODELS if models is None else models

    for task, task_models in model_urls.items():
        checkpoints_dir = os.path.join(root, task)
        # Create the task directory up front, even when no download is needed.
        os.makedirs(checkpoints_dir, exist_ok=True)

        for model_name, url in task_models.items():
            model_filename = f"{model_name}_torchscript.pt2"
            model_path = os.path.join(checkpoints_dir, model_filename)

            if not os.path.exists(model_path):
                print(f"Downloading {task} {model_name} model...")
                download_file(url, model_path)
                print(f"{task} {model_name} model downloaded successfully.")
            else:
                print(f"{task} {model_name} model already exists. Skipping download.")

if __name__ == "__main__":
    # Run the downloader only when executed as a script, not when imported.
    main()