File size: 947 Bytes
32b2aaa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import logging
from pathlib import Path, PurePath
from urllib.parse import quote

import torch

# Name of the pretrained run. It is used both as the sub-directory in the
# Hugging Face repository URL and as the local directory name under
# "model_repo/" when no explicit run directory is given.
RUN_NAME: str = "enhancer_stage2"

# Module-level logger.
logger: logging.Logger = logging.getLogger(__name__)


def get_source_url(relpath: str | Path) -> str:
    """Build the Hugging Face download URL for a file of the ``RUN_NAME`` run.

    Args:
        relpath: Path of the file relative to the run directory, e.g.
            ``"ds/G/latest"``. ``download`` passes the same value to
            ``get_target_path`` as a local filesystem path, so ``Path``
            objects are accepted here as well.

    Returns:
        A ``resolve/main`` URL on the ``ResembleAI/resemble-enhance``
        repository with ``?download=true`` appended.
    """
    # str(Path) uses the OS separator (backslashes on Windows), which is not a
    # valid URL path separator; always emit forward slashes instead.
    if isinstance(relpath, PurePath):
        relpath = relpath.as_posix()
    # relpath is a raw filesystem-style path (it is also used to build the local
    # target), so percent-encode characters such as spaces, '#' or '?' that
    # would otherwise corrupt the URL. '/' stays unescaped as the separator.
    return f"https://huggingface.co/ResembleAI/resemble-enhance/resolve/main/{RUN_NAME}/{quote(relpath)}?download=true"


def get_target_path(relpath: str | Path, run_dir: str | Path | None = None):
    if run_dir is None:
        run_dir = Path(__file__).parent.parent / "model_repo" / RUN_NAME
    return Path(run_dir) / relpath


def download(run_dir: str | Path | None = None) -> Path:
    """Fetch the pretrained ``RUN_NAME`` checkpoint files that are missing locally.

    Each required file is downloaded from the URL given by ``get_source_url``
    to the location given by ``get_target_path``. Files that already exist are
    skipped, so repeated calls only fetch what is missing. Existence is the
    only check: a present but corrupted file is not re-downloaded.

    Args:
        run_dir: Local run directory; ``None`` uses the default location chosen
            by ``get_target_path``.

    Returns:
        The run directory as a ``Path``.
    """
    relpaths = ("hparams.yaml", "ds/G/latest", "ds/G/default/mp_rank_00_model_states.pt")
    for relpath in relpaths:
        path = get_target_path(relpath, run_dir=run_dir)
        if path.exists():
            logger.debug("Found %s, skipping download", path)
            continue
        url = get_source_url(relpath)
        # The destination directory must exist before the file is written into it.
        path.parent.mkdir(parents=True, exist_ok=True)
        logger.info("Downloading %s to %s", url, path)
        torch.hub.download_url_to_file(url, str(path))
    return get_target_path("", run_dir=run_dir)