from diffusers import DiffusionPipeline
import torch
import os
import yaml
from audioldm_train.utilities.tools import build_dataset_json_from_list
from infer_mos5 import infer  # Importing the infer function


class MOSDiffusionPipeline(DiffusionPipeline):
    """``DiffusionPipeline`` wrapper that runs MOS inference via ``infer_mos5.infer``.

    The constructor loads an experiment YAML config, optionally overrides the
    checkpoint to reload from, and builds the inference dataset from a prompt
    list file. Calling the pipeline forwards all of that to ``infer``.

    Example::

        pipeline = MOSDiffusionPipeline(
            config_yaml="audioldm_train/config/mos_as_token/qa_mdt.yaml",
            list_inference="/content/qa-mdt/test_prompts/good_prompts_1.lst",
            reload_from_ckpt="/content/qa-mdt/checkpoint_389999.ckpt",
        )
        pipeline()
    """

    def __init__(self, config_yaml, list_inference, reload_from_ckpt=None):
        """
        Initialize the MOS Diffusion pipeline.

        Args:
            config_yaml (str): Path to the YAML configuration file. Its file
                name (up to the first ".") becomes ``exp_name`` and its parent
                directory name becomes ``exp_group_name``.
            list_inference (str): Path to the file containing inference prompts;
                passed to ``build_dataset_json_from_list``.
            reload_from_ckpt (str, optional): Checkpoint path to reload from.
                When given, it overrides ``configs["reload_from_ckpt"]``.
        """
        super().__init__()

        self.config_yaml = config_yaml
        self.list_inference = list_inference
        self.reload_from_ckpt = reload_from_ckpt

        # Load the YAML config inside a context manager so the file handle is
        # closed. NOTE(review): FullLoader is kept for compatibility with the
        # original; only load config files from trusted sources.
        with open(self.config_yaml, "r", encoding="utf-8") as config_file:
            self.configs = yaml.load(config_file, Loader=yaml.FullLoader)

        # Override the checkpoint if one was provided. Assumes the YAML's top
        # level is a mapping -- TODO confirm for all configs.
        if self.reload_from_ckpt is not None:
            self.configs["reload_from_ckpt"] = self.reload_from_ckpt

        self.dataset_key = build_dataset_json_from_list(self.list_inference)

        # Split only the *file name* on "." (not the whole path), so relative
        # paths like "./cfg/x.yaml" or dotted directories still give "x".
        # For ordinary paths this matches the previous naming exactly.
        self.exp_name = os.path.basename(self.config_yaml).split(".")[0]
        self.exp_group_name = os.path.basename(os.path.dirname(self.config_yaml))

    @torch.no_grad()
    def __call__(self, *args, **kwargs):
        """
        Run the MOS Diffusion Pipeline.

        Delegates to ``infer`` from ``infer_mos5`` with the dataset, config,
        and experiment naming prepared in ``__init__``. Gradient tracking is
        disabled for the duration of the call.

        Args:
            *args: Accepted for API compatibility; currently ignored.
            **kwargs: Accepted for API compatibility; currently ignored (they are
                NOT applied as config overrides).

        Returns:
            None. Inference is performed by ``infer``; this method does not
            return its result.
        """
        infer(
            dataset_key=self.dataset_key,
            configs=self.configs,
            config_yaml_path=self.config_yaml,
            exp_group_name=self.exp_group_name,
            exp_name=self.exp_name,
        )