File size: 3,466 Bytes
225b9d3
 
893807d
 
 
f38a7f2
225b9d3
893807d
225b9d3
893807d
225b9d3
893807d
225b9d3
 
 
 
 
893807d
225b9d3
 
 
893807d
 
 
225b9d3
 
 
 
893807d
225b9d3
 
 
bf7634c
225b9d3
 
 
893807d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
225b9d3
 
 
 
 
61de3b8
893807d
225b9d3
893807d
 
 
 
225b9d3
 
 
893807d
 
 
 
 
 
 
 
225b9d3
893807d
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
from diffusers import DiffusionPipeline
import os
import sys
from huggingface_hub import HfApi, hf_hub_download
from .tools import build_dataset_json_from_list
import torch

class MOSDiffusionPipeline(DiffusionPipeline):
    """
    Thin ``DiffusionPipeline`` wrapper around the MOS inference script.

    The actual generation is delegated to ``infer.infer_mos5.infer``; this class
    loads the YAML config, builds the dataset key for the prompt list, and can
    fetch the repository folders the inference code depends on.
    """

    # Top-level folders of the Hub repository needed by the inference code.
    # Subclasses may override this to fetch a different set.
    REQUIRED_FOLDERS = ("audioldm_train", "checkpoints", "infer", "log", "taming", "test_prompts")

    def __init__(self, config_yaml, list_inference, reload_from_ckpt=None, base_folder=None):
        """
        Initialize the MOS Diffusion pipeline from a local config and prompt list.

        Nothing is downloaded here: ``config_yaml`` and ``list_inference`` are read
        from disk immediately, so they must already exist. Use
        ``download_required_folders`` to fetch the repository folders.

        Args:
            config_yaml (str): Path to the YAML configuration file.
            list_inference (str): Path to the file containing inference prompts.
            reload_from_ckpt (str, optional): Checkpoint path to reload from. When
                given, it overrides the ``reload_from_ckpt`` entry of the config.
            base_folder (str, optional): Base folder to store downloaded files.
                Defaults to the current working directory.
        """
        super().__init__()

        self.base_folder = base_folder if base_folder else os.getcwd()
        self.repo_id = "jadechoghari/qa-mdt"
        self.config_yaml = config_yaml
        self.list_inference = list_inference
        self.reload_from_ckpt = reload_from_ckpt

        self.configs = self.load_yaml(self.config_yaml)
        if self.reload_from_ckpt is not None:
            # The explicit constructor argument wins over the value in the YAML.
            self.configs["reload_from_ckpt"] = self.reload_from_ckpt

        self.dataset_key = build_dataset_json_from_list(self.list_inference)
        # ".../<group>/<name>.yaml" -> exp_name "<name>", exp_group_name "<group>".
        # Strip the extension from the basename only, so that paths such as
        # "./cfg/x.yaml" or "cfg/v1.2/x.yaml" do not collapse to "" or "cfg/v1".
        self.exp_name = os.path.splitext(os.path.basename(self.config_yaml))[0]
        self.exp_group_name = os.path.basename(os.path.dirname(self.config_yaml))

    def download_required_folders(self):
        """
        Download the repository folders in ``REQUIRED_FOLDERS`` into ``base_folder``.

        Files already present under ``base_folder`` are skipped. Afterwards
        ``base_folder`` is appended to ``sys.path`` (only once) so modules in the
        downloaded folders become importable.
        """
        api = HfApi()
        files = api.list_repo_files(repo_id=self.repo_id)

        # Match "<folder>/" rather than the bare name so that e.g. "log" does not
        # also select a sibling like "logo.png", nor "infer" select "inference.py".
        prefixes = tuple(folder + "/" for folder in self.REQUIRED_FOLDERS)
        files_to_download = [f for f in files if f.startswith(prefixes)]

        for file in files_to_download:
            local_file_path = os.path.join(self.base_folder, file)
            if not os.path.exists(local_file_path):
                # Let the hub client place the file at <base_folder>/<file> itself.
                # Renaming the path returned by a cache download would move the
                # cache entry out from under the cache and fails across filesystems.
                hf_hub_download(
                    repo_id=self.repo_id,
                    filename=file,
                    local_dir=self.base_folder,
                )

        if self.base_folder not in sys.path:
            sys.path.append(self.base_folder)

    def load_yaml(self, yaml_path):
        """
        Load and return the parsed contents of the YAML file at ``yaml_path``.

        ``yaml.safe_load`` is used so that the file cannot instantiate arbitrary
        Python objects.
        """
        import yaml
        with open(yaml_path, "r", encoding="utf-8") as f:
            return yaml.safe_load(f)

    @torch.no_grad()
    def __call__(self, *args, **kwargs):
        """
        Run the MOS Diffusion Pipeline by calling ``infer`` from ``infer_mos5.py``.

        Positional and keyword arguments are accepted for call-signature
        compatibility but are currently ignored; all settings come from the
        values prepared in ``__init__``.

        Returns:
            Whatever ``infer`` returns.
        """
        from .infer.infer_mos5 import infer

        return infer(
            dataset_key=self.dataset_key,
            configs=self.configs,
            config_yaml_path=self.config_yaml,
            exp_group_name=self.exp_group_name,
            exp_name=self.exp_name,
        )

# Usage example: build the pipeline from the QA-MDT config, prompt list and
# checkpoint paths below, then run inference.
if __name__ == "__main__":
    # NOTE(review): these paths are read during construction, so they must
    # already exist relative to the current working directory.
    example_pipeline = MOSDiffusionPipeline(
        base_folder=None,
        config_yaml="audioldm_train/config/mos_as_token/qa_mdt.yaml",
        list_inference="test_prompts/good_prompts_1.lst",
        reload_from_ckpt="checkpoints/checkpoint_389999.ckpt",
    )

    # Invoking the pipeline object runs the inference routine.
    example_pipeline()