# This module is from [WeNet](https://github.com/wenet-e2e/wenet).
#
# ## Citations
#
# ```bibtex
# @inproceedings{yao2021wenet,
#   title={WeNet: Production oriented Streaming and Non-streaming End-to-End Speech Recognition Toolkit},
#   author={Yao, Zhuoyuan and Wu, Di and Wang, Xiong and Zhang, Binbin and Yu, Fan and Yang, Chao and Peng, Zhendong and Chen, Xiaoyu and Xie, Lei and Lei, Xin},
#   booktitle={Proc. Interspeech},
#   year={2021},
#   address={Brno, Czech Republic},
#   organization={IEEE}
# }
#
# @article{zhang2022wenet,
#   title={WeNet 2.0: More Productive End-to-End Speech Recognition Toolkit},
#   author={Zhang, Binbin and Wu, Di and Peng, Zhendong and Song, Xingchen and Yao, Zhuoyuan and Lv, Hang and Xie, Lei and Yang, Chao and Pan, Fuping and Niu, Jianwei},
#   journal={arXiv preprint arXiv:2203.15455},
#   year={2022}
# }
# ```

"""Utility functions for Transformer."""

import math
from typing import List, Tuple

import torch
from torch.nn.utils.rnn import pad_sequence

IGNORE_ID = -1


def pad_list(xs: List[torch.Tensor], pad_value: int):
    """Perform padding for the list of tensors.

    Args:
        xs (List): List of Tensors [(T_1, `*`), (T_2, `*`), ..., (T_B, `*`)].
        pad_value (float): Value for padding.

    Returns:
        Tensor: Padded tensor (B, Tmax, `*`).

    Examples:
        >>> x = [torch.ones(4), torch.ones(2), torch.ones(1)]
        >>> x
        [tensor([1., 1., 1., 1.]), tensor([1., 1.]), tensor([1.])]
        >>> pad_list(x, 0)
        tensor([[1., 1., 1., 1.],
                [1., 1., 0., 0.],
                [1., 0., 0., 0.]])
    """
    n_batch = len(xs)
    max_len = max([x.size(0) for x in xs])
    pad = torch.zeros(n_batch, max_len, dtype=xs[0].dtype, device=xs[0].device)
    pad = pad.fill_(pad_value)
    for i in range(n_batch):
        pad[i, : xs[i].size(0)] = xs[i]

    return pad


def add_blank(ys_pad: torch.Tensor, blank: int, ignore_id: int) -> torch.Tensor:
    """Prepend a blank token for the transducer predictor.

    Args:
        ys_pad (torch.Tensor): batch of padded target sequences (B, Lmax)
        blank (int): index of <blank>
        ignore_id (int): index of padding

    Returns:
        ys_in (torch.Tensor) : (B, Lmax + 1)

    Examples:
        >>> blank = 0
        >>> ignore_id = -1
        >>> ys_pad
        tensor([[ 1,  2,  3,  4,  5],
                [ 4,  5,  6, -1, -1],
                [ 7,  8,  9, -1, -1]], dtype=torch.int32)
        >>> ys_in = add_blank(ys_pad, 0, -1)
        >>> ys_in
        tensor([[0, 1, 2, 3, 4, 5],
                [0, 4, 5, 6, 0, 0],
                [0, 7, 8, 9, 0, 0]])
    """
    bs = ys_pad.size(0)
    _blank = torch.tensor(
        [blank], dtype=torch.long, requires_grad=False, device=ys_pad.device
    )
    _blank = _blank.repeat(bs).unsqueeze(1)  # [bs, 1]
    out = torch.cat([_blank, ys_pad], dim=1)  # [bs, Lmax + 1]
    return torch.where(out == ignore_id, blank, out)


def add_sos_eos(
    ys_pad: torch.Tensor, sos: int, eos: int, ignore_id: int
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Add <sos> and <eos> labels.

    Args:
        ys_pad (torch.Tensor): batch of padded target sequences (B, Lmax)
        sos (int): index of <sos>
        eos (int): index of <eos>
        ignore_id (int): index of padding

    Returns:
        ys_in (torch.Tensor) : (B, Lmax + 1)
        ys_out (torch.Tensor) : (B, Lmax + 1)

    Examples:
        >>> sos_id = 10
        >>> eos_id = 11
        >>> ignore_id = -1
        >>> ys_pad
        tensor([[ 1,  2,  3,  4,  5],
                [ 4,  5,  6, -1, -1],
                [ 7,  8,  9, -1, -1]], dtype=torch.int32)
        >>> ys_in, ys_out = add_sos_eos(ys_pad, sos_id, eos_id, ignore_id)
        >>> ys_in
        tensor([[10,  1,  2,  3,  4,  5],
                [10,  4,  5,  6, 11, 11],
                [10,  7,  8,  9, 11, 11]])
        >>> ys_out
        tensor([[ 1,  2,  3,  4,  5, 11],
                [ 4,  5,  6, 11, -1, -1],
                [ 7,  8,  9, 11, -1, -1]])
    """
    _sos = torch.tensor(
        [sos], dtype=torch.long, requires_grad=False, device=ys_pad.device
    )
    _eos = torch.tensor(
        [eos], dtype=torch.long, requires_grad=False, device=ys_pad.device
    )
    ys = [y[y != ignore_id] for y in ys_pad]  # parse padded ys
    ys_in = [torch.cat([_sos, y], dim=0) for y in ys]
    ys_out = [torch.cat([y, _eos], dim=0) for y in ys]
    return pad_list(ys_in, eos), pad_list(ys_out, ignore_id)
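
# A minimal usage sketch, assuming sos/eos ids of 10 and 11 (example values,
# not constants defined in this module): `add_sos_eos` builds the
# teacher-forcing input/target pair for an attention decoder.
def _demo_add_sos_eos() -> None:
    ys_pad = torch.tensor([[1, 2, 3, 4, 5], [4, 5, 6, -1, -1]])
    ys_in, ys_out = add_sos_eos(ys_pad, sos=10, eos=11, ignore_id=IGNORE_ID)
    # `ys_in` prepends <sos> and pads with <eos>; it feeds the decoder.
    assert ys_in[0].tolist() == [10, 1, 2, 3, 4, 5]
    # `ys_out` appends <eos> and keeps IGNORE_ID padding, so the
    # cross-entropy loss can skip the padded positions.
    assert ys_out[1].tolist() == [4, 5, 6, 11, -1, -1]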

def reverse_pad_list(
    ys_pad: torch.Tensor, ys_lens: torch.Tensor, pad_value: float = -1.0
) -> torch.Tensor:
    """Reverse padding for the list of tensors.

    Args:
        ys_pad (tensor): The padded tensor (B, Tokenmax).
        ys_lens (tensor): The lens of token seqs (B)
        pad_value (float): Value for padding.

    Returns:
        Tensor: Padded tensor (B, Tokenmax).

    Examples:
        >>> x
        tensor([[1, 2, 3, 4], [5, 6, 7, 0], [8, 9, 0, 0]])
        >>> reverse_pad_list(x, torch.tensor([4, 3, 2]), 0)
        tensor([[4, 3, 2, 1],
                [7, 6, 5, 0],
                [9, 8, 0, 0]])
    """
    r_ys_pad = pad_sequence(
        [(torch.flip(y.int()[:i], [0])) for y, i in zip(ys_pad, ys_lens)],
        True,
        pad_value,
    )
    return r_ys_pad


def th_accuracy(
    pad_outputs: torch.Tensor, pad_targets: torch.Tensor, ignore_label: int
) -> float:
    """Calculate accuracy.

    Args:
        pad_outputs (Tensor): Prediction tensors (B * Lmax, D).
        pad_targets (LongTensor): Target label tensors (B, Lmax).
        ignore_label (int): Ignore label id.

    Returns:
        float: Accuracy value (0.0 - 1.0).
    """
    pad_pred = pad_outputs.view(
        pad_targets.size(0), pad_targets.size(1), pad_outputs.size(1)
    ).argmax(2)
    mask = pad_targets != ignore_label
    numerator = torch.sum(
        pad_pred.masked_select(mask) == pad_targets.masked_select(mask)
    )
    denominator = torch.sum(mask)
    return float(numerator) / float(denominator)
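
# A minimal sketch of how `th_accuracy` is called, assuming illustrative
# shapes (B=2, Lmax=2, D=5): logits arrive flattened to (B * Lmax, D), and
# positions whose target equals `ignore_label` are excluded from the average.
def _demo_th_accuracy() -> None:
    pad_targets = torch.tensor([[1, 2], [3, IGNORE_ID]])
    pad_outputs = torch.zeros(4, 5)
    pad_outputs[0, 1] = 1.0  # argmax 1, matches target 1
    pad_outputs[1, 2] = 1.0  # argmax 2, matches target 2
    pad_outputs[2, 4] = 1.0  # argmax 4, target is 3 -> miss
    acc = th_accuracy(pad_outputs, pad_targets, ignore_label=IGNORE_ID)
    assert abs(acc - 2.0 / 3.0) < 1e-6  # the ignored position is not counted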
""" pad_pred = pad_outputs.view( pad_targets.size(0), pad_targets.size(1), pad_outputs.size(1) ).argmax(2) mask = pad_targets != ignore_label numerator = torch.sum( pad_pred.masked_select(mask) == pad_targets.masked_select(mask) ) denominator = torch.sum(mask) return float(numerator) / float(denominator) def get_rnn(rnn_type: str) -> torch.nn.Module: assert rnn_type in ["rnn", "lstm", "gru"] if rnn_type == "rnn": return torch.nn.RNN elif rnn_type == "lstm": return torch.nn.LSTM else: return torch.nn.GRU def get_activation(act): """Return activation function.""" # Lazy load to avoid unused import from modules.wenet_extractor.transformer.swish import Swish activation_funcs = { "hardtanh": torch.nn.Hardtanh, "tanh": torch.nn.Tanh, "relu": torch.nn.ReLU, "selu": torch.nn.SELU, "swish": getattr(torch.nn, "SiLU", Swish), "gelu": torch.nn.GELU, } return activation_funcs[act]() def get_subsample(config): input_layer = config["encoder_conf"]["input_layer"] assert input_layer in ["conv2d", "conv2d6", "conv2d8"] if input_layer == "conv2d": return 4 elif input_layer == "conv2d6": return 6 elif input_layer == "conv2d8": return 8 def remove_duplicates_and_blank(hyp: List[int]) -> List[int]: new_hyp: List[int] = [] cur = 0 while cur < len(hyp): if hyp[cur] != 0: new_hyp.append(hyp[cur]) prev = cur while cur < len(hyp) and hyp[cur] == hyp[prev]: cur += 1 return new_hyp def replace_duplicates_with_blank(hyp: List[int]) -> List[int]: new_hyp: List[int] = [] cur = 0 while cur < len(hyp): new_hyp.append(hyp[cur]) prev = cur cur += 1 while cur < len(hyp) and hyp[cur] == hyp[prev] and hyp[cur] != 0: new_hyp.append(0) cur += 1 return new_hyp def log_add(args: List[int]) -> float: """ Stable log add """ if all(a == -float("inf") for a in args): return -float("inf") a_max = max(args) lsp = math.log(sum(math.exp(a - a_max) for a in args)) return a_max + lsp