File size: 3,095 Bytes
71de706
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import torch


class WhisperMixin:
    """Mixin adding Whisper feature extraction, transcription and encoder
    embeddings to an audio-signal class.

    The methods below rely on the host class providing ``clone()``,
    ``resample(sample_rate)`` and an ``audio_data`` tensor indexed as
    ``audio_data[:, 0, :]`` (NOTE(review): presumably (batch, channels, time);
    confirm against the host class).

    The Whisper processor and model are loaded lazily, per instance, on the
    first call that needs them (or explicitly via ``setup_whisper``).
    """

    # Class-level default; ``setup_whisper`` sets it to True on the instance.
    is_initialized = False

    def setup_whisper(
        self,
        pretrained_model_name_or_path: str = "openai/whisper-base.en",
        device: "str | torch.device" = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu"
        ),
    ):
        """Load a Whisper processor and model and attach them to this instance.

        Parameters
        ----------
        pretrained_model_name_or_path : str, optional
            Model id or local path passed to ``from_pretrained``,
            by default "openai/whisper-base.en".
        device : str | torch.device, optional
            Device the model is moved to. The default is CUDA if it was
            available when this module was imported, otherwise CPU.
        """
        # Imported here so ``transformers`` is only needed when Whisper is used.
        from transformers import WhisperForConditionalGeneration
        from transformers import WhisperProcessor

        self.whisper_device = device
        self.whisper_processor = WhisperProcessor.from_pretrained(
            pretrained_model_name_or_path
        )
        self.whisper_model = WhisperForConditionalGeneration.from_pretrained(
            pretrained_model_name_or_path
        ).to(self.whisper_device)
        self.is_initialized = True

    def get_whisper_features(self) -> torch.Tensor:
        """Preprocess the audio signal as per the Whisper model's training config.

        A copy of the signal is resampled to the feature extractor's sampling
        rate, only the first channel of every batch item is kept, and the
        Whisper processor converts the result into model input features.

        Returns
        -------
        torch.Tensor
            The input features of the audio signal, one row per batch item.
            Shape: (batch, ...) where the trailing dimensions are whatever the
            Whisper feature extractor produces.
        """
        if not self.is_initialized:
            self.setup_whisper()

        sampling_rate = self.whisper_processor.feature_extractor.sampling_rate

        # Resample a clone so this signal is left untouched. ``.numpy()`` raises
        # on tensors that require grad or live on a GPU, hence detach()/cpu().
        first_channel = (
            self.clone().resample(sampling_rate).audio_data[:, 0, :].detach().cpu()
        )
        # The processor expects a list of 1-D arrays, one per batch item.
        raw_speech = list(first_channel.numpy())

        with torch.inference_mode():
            input_features = self.whisper_processor(
                raw_speech,
                sampling_rate=sampling_rate,
                return_tensors="pt",
            ).input_features

        return input_features

    def get_whisper_transcript(self) -> str:
        """Get the transcript of the audio signal using the Whisper model.

        Only the transcript of the first batch item is returned.

        Returns
        -------
        str
            The transcript of the audio signal, including special tokens such
            as <|startoftranscript|> and <|endoftext|> (``batch_decode`` is
            called without skipping special tokens).
        """
        if not self.is_initialized:
            self.setup_whisper()

        input_features = self.get_whisper_features()

        with torch.inference_mode():
            input_features = input_features.to(self.whisper_device)
            generated_ids = self.whisper_model.generate(inputs=input_features)

        transcription = self.whisper_processor.batch_decode(generated_ids)
        return transcription[0]

    def get_whisper_embeddings(self) -> torch.Tensor:
        """Get the last hidden state of the Whisper encoder for the audio signal.

        Returns
        -------
        torch.Tensor
            The encoder's ``last_hidden_state``, one entry per batch item.
            NOTE(review): expected shape (batch, seq_len, hidden_size); confirm.
        """
        if not self.is_initialized:
            self.setup_whisper()

        input_features = self.get_whisper_features()
        encoder = self.whisper_model.get_encoder()

        with torch.inference_mode():
            input_features = input_features.to(self.whisper_device)
            embeddings = encoder(input_features)

        return embeddings.last_hidden_state