File size: 4,690 Bytes
0d80816
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import os
import json
import numpy as np
from text import text_to_sequence
from text.text_token_collation import phoneIDCollation
from models.tts.base.tts_dataset import (
    TTSDataset, 
    TTSCollator,
    TTSTestDataset,
    TTSTestCollator
)


class VITSDataset(TTSDataset):
    """Training/validation dataset for VITS, built on ``TTSDataset``.

    Construction, item access and length are forwarded to the parent class
    unchanged; only :meth:`get_metadata` adds VITS-specific behaviour by
    dropping utterances that are shorter than one training segment.
    """

    def __init__(self, cfg, dataset, is_valid):
        super().__init__(cfg, dataset, is_valid=is_valid)

    def __getitem__(self, index):
        # Nothing VITS-specific is added: hand back the parent's item as is.
        return super().__getitem__(index)

    def __len__(self):
        return super().__len__()

    def get_metadata(self):
        """Load the metadata JSON and filter out too-short utterances.

        An entry is kept unless its frame count
        (``Duration * sample_rate // hop_size``) is below
        ``segment_size // hop_size``.

        Returns:
            list: The retained metadata entries, in file order.
        """
        with open(self.metafile_path, "r", encoding="utf-8") as meta_file:
            entries = json.load(meta_file)

        def covers_one_segment(entry):
            seconds = entry["Duration"]
            preprocess = self.cfg.preprocess
            n_frames = seconds * preprocess.sample_rate // preprocess.hop_size
            min_frames = preprocess.segment_size // preprocess.hop_size
            # Written as a negated "<" so the comparison matches the
            # skip-condition exactly (including for non-comparable floats).
            return not n_frames < min_frames

        return [entry for entry in entries if covers_one_segment(entry)]


class VITSCollator(TTSCollator):
    """Training collator for VITS.

    Both construction and batching are delegated unchanged to
    ``TTSCollator``.
    """

    def __init__(self, cfg):
        super().__init__(cfg)

    def __call__(self, batch):
        # The parent's collated batch is returned without modification.
        return super().__call__(batch)


class VITSTestDataset(TTSTestDataset):
    """Inference-time dataset yielding speaker ids and input id sequences.

    Depending on the ``cfg.preprocess`` switches the constructor loads the
    speaker lookup tables (``use_spkid``) and pre-computes one id sequence
    per utterance (``use_text`` or ``use_phone``). Utterances are keyed by
    ``"{Dataset}_{Uid}"``.
    """

    def __init__(self, args, cfg):
        """Load speaker tables and pre-compute the per-utterance sequences.

        Args:
            args: Run arguments; ``args.dataset`` names the sub-directory of
                ``cfg.preprocess.processed_dir`` that is read.
            cfg: Experiment configuration; only ``cfg.preprocess`` is read
                here.

        Raises:
            AssertionError: If a ``.phone`` file does not contain exactly
                one line.
        """
        super().__init__(args, cfg)

        if cfg.preprocess.use_spkid:
            self._load_speaker_tables(self._processed_data_dir(args, cfg), cfg)

        if cfg.preprocess.use_text or cfg.preprocess.use_phone:
            self.utt2seq = {}

            # Raw text takes precedence over phone files when both switches
            # are enabled.
            from_text = cfg.preprocess.use_text

            # The phone directory is resolved here rather than relying on a
            # variable bound in the speaker-id branch, which is not executed
            # when ``use_spkid`` is disabled.
            phone_root = None
            if not from_text:
                phone_root = os.path.join(
                    self._processed_data_dir(args, cfg),
                    cfg.preprocess.phone_dir,
                )

            # One collation object per dataset name instead of one per
            # utterance.
            # NOTE(review): assumes phoneIDCollation keeps no per-utterance
            # state -- confirm against its implementation.
            collators = {}

            for utt_info in self.metadata:
                dataset = utt_info["Dataset"]
                uid = utt_info["Uid"]
                utt = "{}_{}".format(dataset, uid)

                if from_text:
                    sequence = text_to_sequence(
                        utt_info["Text"], cfg.preprocess.text_cleaners
                    )
                else:
                    # Load the phoneme sequence from the phone file.
                    phones_seq = self._read_phones(
                        os.path.join(phone_root, uid + '.phone')
                    )
                    if dataset not in collators:
                        collators[dataset] = phoneIDCollation(cfg, dataset=dataset)
                    sequence = collators[dataset].get_phone_id_sequence(
                        cfg, phones_seq
                    )

                self.utt2seq[utt] = sequence

    @staticmethod
    def _processed_data_dir(args, cfg):
        """Return the directory holding the processed files of ``args.dataset``."""
        return os.path.join(cfg.preprocess.processed_dir, args.dataset)

    def _load_speaker_tables(self, processed_data_dir, cfg):
        """Populate ``self.spk2id`` and ``self.utt2spk`` from disk."""
        spk2id_path = os.path.join(processed_data_dir, cfg.preprocess.spk2id)
        with open(spk2id_path, "r", encoding="utf-8") as f:
            self.spk2id = json.load(f)

        utt2spk_path = os.path.join(processed_data_dir, cfg.preprocess.utt2spk)
        self.utt2spk = dict()
        # NOTE(review): opened with the platform default encoding, as
        # before; the writer's encoding is not visible from this file.
        with open(utt2spk_path, "r") as f:
            for line in f:
                line = line.strip()
                if not line:
                    # Tolerate blank (e.g. trailing) lines instead of
                    # failing on the tuple unpacking below.
                    continue
                utt, spk = line.split('\t')
                self.utt2spk[utt] = spk

    @staticmethod
    def _read_phones(phone_path):
        """Return the space-separated phones of a single-line phone file."""
        # NOTE(review): opened with the platform default encoding, as
        # before; the writer's encoding is not visible from this file.
        with open(phone_path, 'r') as fin:
            lines = fin.readlines()
        assert len(lines) == 1, (
            "expected exactly one line in {}, found {}".format(
                phone_path, len(lines)
            )
        )
        return lines[0].strip().split(' ')

    def __getitem__(self, index):
        """Assemble the feature dict of the utterance at ``index``.

        Returns:
            dict: May contain ``"spk_id"`` (one-element ``np.int32`` array)
            when ``use_spkid`` is set, and ``"phone_seq"`` (array of the
            pre-computed sequence) plus ``"phone_len"`` (its length) when
            ``use_phone`` or ``use_text`` is set.
        """
        utt_info = self.metadata[index]

        dataset = utt_info["Dataset"]
        uid = utt_info["Uid"]
        utt = "{}_{}".format(dataset, uid)

        single_feature = dict()

        if self.cfg.preprocess.use_spkid:
            single_feature["spk_id"] = np.array(
                [self.spk2id[self.utt2spk[utt]]], dtype=np.int32
            )

        if self.cfg.preprocess.use_phone or self.cfg.preprocess.use_text:
            # The key is "phone_seq" for both the text and the phone input.
            single_feature["phone_seq"] = np.array(self.utt2seq[utt])
            single_feature["phone_len"] = len(self.utt2seq[utt])

        return single_feature

    def get_metadata(self):
        """Return the parsed content of ``self.metafile_path`` (unfiltered)."""
        with open(self.metafile_path, "r", encoding="utf-8") as f:
            metadata = json.load(f)
        return metadata

    def __len__(self):
        return len(self.metadata)


class VITSTestCollator(TTSTestCollator):
    """Inference-time collator for VITS.

    Batching is delegated unchanged to ``TTSTestCollator``.
    """

    def __init__(self, cfg):
        # Only the configuration is stored; the parent initialiser is not
        # invoked here.
        # NOTE(review): VITSCollator calls super().__init__(cfg) -- confirm
        # the difference is intentional.
        self.cfg = cfg

    def __call__(self, batch):
        collated = super().__call__(batch)
        return collated