# zhzluke96 · "update" · commit 32b2aaa
# raw · history · blame · No virus · 790 Bytes
import logging
from functools import cache
from pathlib import Path

import torch

from ..denoiser.denoiser import Denoiser
from ..inference import inference
from .hparams import HParams
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
@cache
def load_denoiser(run_dir, device):
    """Build a Denoiser, optionally restoring weights from a training run.

    Results are memoized by ``functools.cache`` on ``(run_dir, device)``, so
    repeated calls with equal arguments return the *same* module instance;
    callers should treat it as shared and not mutate it.

    Args:
        run_dir: Training-run directory (``str`` or ``pathlib.Path``). Its
            hyper-parameters are read with ``HParams.load`` and its weights
            from ``ds/G/default/mp_rank_00_model_states.pt``. ``None`` builds
            a model from default ``HParams()`` with no checkpoint loaded.
        device: Target passed to ``Module.to`` (e.g. ``"cpu"`` or a
            ``torch.device``). It must be hashable because it is a cache key.

    Returns:
        The Denoiser, switched to eval mode and moved to ``device``.
    """
    if run_dir is None:
        # No checkpoint: weights are whatever Denoiser(HParams()) initializes.
        logger.warning("load_denoiser: run_dir is None, no checkpoint loaded")
        denoiser = Denoiser(HParams())
    else:
        # Accept plain strings too; the checkpoint path is built with "/".
        run_dir = Path(run_dir)
        hp = HParams.load(run_dir)
        denoiser = Denoiser(hp)
        # Layout looks like a DeepSpeed checkpoint; weights live under "module".
        path = run_dir / "ds" / "G" / "default" / "mp_rank_00_model_states.pt"
        # Load onto CPU first; the whole module is moved to `device` below.
        # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
        state_dict = torch.load(path, map_location="cpu")["module"]
        denoiser.load_state_dict(state_dict)
    # Applied on both paths so callers always get an inference-ready model
    # (eval mode, on the requested device).
    denoiser.eval()
    denoiser.to(device)
    return denoiser
@torch.inference_mode()
def denoise(dwav, sr, run_dir, device):
    """Denoise a waveform with the model cached for ``(run_dir, device)``.

    The whole call runs under ``torch.inference_mode`` (no autograd tracking).
    The model is obtained via ``load_denoiser``, which is memoized, so it is
    built/loaded only on the first call for a given ``(run_dir, device)``.

    Args:
        dwav: Input waveform, forwarded unchanged to ``inference``
            (presumably a 1-D tensor -- TODO confirm against ``inference``).
        sr: Presumably the sample rate of ``dwav``; forwarded to ``inference``.
        run_dir: Training-run directory or ``None``; see ``load_denoiser``.
        device: Device used both for the model and by ``inference``.

    Returns:
        The result of ``inference`` for this model and input.
    """
    model = load_denoiser(run_dir, device)
    return inference(dwav=dwav, sr=sr, model=model, device=device)