jadechoghari commited on
Commit
225b9d3
1 Parent(s): 5085882

Create pipeline.py

Browse files
Files changed (1) hide show
  1. pipeline.py +68 -0
pipeline.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from diffusers import DiffusionPipeline
2
+ import torch
3
+ import os
4
+ import yaml
5
+ from audioldm_train.utilities.tools import build_dataset_json_from_list
6
+ from infer_mos5 import infer # Importing the infer function
7
+
8
+ class MOSDiffusionPipeline(DiffusionPipeline):
9
+
10
+ def __init__(self, config_yaml, list_inference, reload_from_ckpt=None):
11
+ """
12
+ Initialize the MOS Diffusion pipeline.
13
+
14
+ Args:
15
+ config_yaml (str): Path to the YAML configuration file.
16
+ list_inference (str): Path to the file containing inference prompts.
17
+ reload_from_ckpt (str, optional): Checkpoint path to reload from.
18
+ """
19
+ super().__init__()
20
+
21
+ # we load and process the yaml config
22
+ self.config_yaml = config_yaml
23
+ self.list_inference = list_inference
24
+ self.reload_from_ckpt = reload_from_ckpt
25
+
26
+ # we load the yaml config
27
+ config_yaml_path = os.path.join(self.config_yaml)
28
+ self.configs = yaml.load(open(config_yaml_path, "r"), Loader=yaml.FullLoader)
29
+
30
+ # override checkpoint if provided--
31
+ if self.reload_from_ckpt is not None:
32
+ self.configs["reload_from_ckpt"] = self.reload_from_ckpt
33
+
34
+ self.dataset_key = build_dataset_json_from_list(self.list_inference)
35
+ self.exp_name = os.path.basename(self.config_yaml.split(".")[0])
36
+ self.exp_group_name = os.path.basename(os.path.dirname(self.config_yaml))
37
+
38
+ @torch.no_grad()
39
+ def __call__(self, *args, **kwargs):
40
+ """
41
+ Run the MOS Diffusion Pipeline. This method calls the infer function from infer_mos5.py.
42
+
43
+ Args:
44
+ *args: Additional arguments.
45
+ **kwargs: Keyword arguments that may contain overrides for configurations.
46
+
47
+ Returns:
48
+ None. Inference is performed and samples are generated.
49
+ """
50
+ # here call the infer function to perform the inference
51
+ infer(
52
+ dataset_key=self.dataset_key,
53
+ configs=self.configs,
54
+ config_yaml_path=self.config_yaml,
55
+ exp_group_name=self.exp_group_name,
56
+ exp_name=self.exp_name
57
+ )
58
+
59
+ # # This is an example of how to use the pipeline
60
+ # if __name__ == "__main__":
61
+ # pipeline = MOSDiffusionPipeline(
62
+ # config_yaml="audioldm_train/config/mos_as_token/qa_mdt.yaml",
63
+ # list_inference="/content/qa-mdt/test_prompts/good_prompts_1.lst",
64
+ # reload_from_ckpt="/content/qa-mdt/checkpoint_389999.ckpt"
65
+ # )
66
+
67
+ # # Run the pipeline
68
+ # pipeline()