import torch


class WhisperMixin:
    is_initialized = False

    def setup_whisper(
        self,
        pretrained_model_name_or_path: str = "openai/whisper-base.en",
        device: torch.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu"
        ),
    ):
        # Import lazily so transformers is only required when Whisper is used.
        from transformers import WhisperForConditionalGeneration
        from transformers import WhisperProcessor

        self.whisper_device = device
        self.whisper_processor = WhisperProcessor.from_pretrained(
            pretrained_model_name_or_path
        )
        self.whisper_model = WhisperForConditionalGeneration.from_pretrained(
            pretrained_model_name_or_path
        ).to(self.whisper_device)
        self.is_initialized = True

    def get_whisper_features(self) -> torch.Tensor:
        """Preprocess the audio signal to match the Whisper model's training
        configuration.

        Returns
        -------
        torch.Tensor
            The preprocessed input features of the audio signal.
            Shape: (batch, n_mels, n_frames)
        """
        if not self.is_initialized:
            self.setup_whisper()

        # Whisper expects mono audio at the feature extractor's sampling rate
        # (16 kHz): take the first channel and move to CPU before converting
        # to numpy.
        signal = self.clone().resample(
            self.whisper_processor.feature_extractor.sampling_rate
        )
        raw_speech = list(signal.audio_data[:, 0, :].cpu().numpy())

        with torch.inference_mode():
            input_features = self.whisper_processor(
                raw_speech,
                sampling_rate=self.whisper_processor.feature_extractor.sampling_rate,
                return_tensors="pt",
            ).input_features

        return input_features

    def get_whisper_transcript(self) -> str:
        """Get the transcript of the audio signal using the Whisper model.

        Returns
        -------
        str
            The transcript of the audio signal, including special tokens
            such as <|startoftranscript|> and <|endoftext|>.
        """
        if not self.is_initialized:
            self.setup_whisper()

        input_features = self.get_whisper_features()

        with torch.inference_mode():
            input_features = input_features.to(self.whisper_device)
            generated_ids = self.whisper_model.generate(inputs=input_features)

        transcription = self.whisper_processor.batch_decode(generated_ids)
        return transcription[0]

    def get_whisper_embeddings(self) -> torch.Tensor:
        """Get the last hidden state embeddings of the audio signal from the
        Whisper encoder.

        Returns
        -------
        torch.Tensor
            The Whisper embeddings of the audio signal.
            Shape: (batch, seq_len, hidden_size)
        """
        if not self.is_initialized:
            self.setup_whisper()

        input_features = self.get_whisper_features()
        encoder = self.whisper_model.get_encoder()

        with torch.inference_mode():
            input_features = input_features.to(self.whisper_device)
            embeddings = encoder(input_features)

        return embeddings.last_hidden_state
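

# --- Usage sketch (illustrative, not part of this module) -------------------
# Assumption: `WhisperMixin` is mixed into an `AudioSignal`-style class that
# provides `.clone()`, `.resample()`, and `.audio_data` with shape
# (batch, channels, samples), as in descript-audiotools. The `AudioSignal`
# import and the "speech.wav" path below are hypothetical placeholders.
#
#   from audiotools import AudioSignal
#
#   signal = AudioSignal("speech.wav")
#   print(signal.get_whisper_transcript())  # "<|startoftranscript|> ... <|endoftext|>"
#   feats = signal.get_whisper_features()   # (batch, n_mels, n_frames)
#   embs = signal.get_whisper_embeddings()  # (batch, seq_len, hidden_size)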