nithinraok committed
Commit 3ae07da
1 Parent(s): b6c2bc0

Create viterbi_decoding.py

Files changed (1)
  1. viterbi_decoding.py +137 -0
viterbi_decoding.py ADDED
@@ -0,0 +1,137 @@
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
V_NEGATIVE_NUM = -3.4e38


def viterbi_decoding(log_probs_batch, y_batch, T_batch, U_batch, viterbi_device):
    """
    Do Viterbi decoding with an efficient algorithm (the only for-loop in the 'forward pass' is over the time dimension).
    Args:
        log_probs_batch: tensor of shape (B, T_max, V). The parts of log_probs_batch which are 'padding' are filled
            with 'V_NEGATIVE_NUM' - a large negative number which represents a very low probability.
        y_batch: tensor of shape (B, U_max) - contains token IDs including blanks in every other position. The parts of
            y_batch which are padding are filled with the number 'V'. V = the number of tokens in the vocabulary + 1 for
            the blank token.
        T_batch: tensor of shape (B, 1) - contains the durations of the log_probs_batch (so we can ignore the
            parts of log_probs_batch which are padding).
        U_batch: tensor of shape (B, 1) - contains the lengths of y_batch (so we can ignore the parts of y_batch
            which are padding).
        viterbi_device: the torch device on which Viterbi decoding will be done.

    Returns:
        alignments_batch: list of lists containing locations for the tokens we align to at each timestep.
            Looks like: [[0, 0, 1, 2, 2, 3, 3, ...], ..., [0, 1, 2, 2, 2, 3, 4, ...]].
            Each list inside alignments_batch is of length T_batch[location of utt in batch].
    """

    B, T_max, _ = log_probs_batch.shape
    U_max = y_batch.shape[1]

    # transfer all tensors to viterbi_device
    log_probs_batch = log_probs_batch.to(viterbi_device)
    y_batch = y_batch.to(viterbi_device)
    T_batch = T_batch.to(viterbi_device)
    U_batch = U_batch.to(viterbi_device)

    # make tensor that we will put at timesteps beyond the duration of the audio
    padding_for_log_probs = V_NEGATIVE_NUM * torch.ones((B, T_max, 1), device=viterbi_device)
    # make log_probs_padded tensor of shape (B, T_max, V + 1) where all of
    # log_probs_padded[:, :, -1] is 'V_NEGATIVE_NUM'
    log_probs_padded = torch.cat((log_probs_batch, padding_for_log_probs), dim=2)

    # initialize v_prev - tensor of the previous timestep's viterbi probabilities, of shape (B, U_max)
    v_prev = V_NEGATIVE_NUM * torch.ones((B, U_max), device=viterbi_device)
    v_prev[:, :2] = torch.gather(input=log_probs_padded[:, 0, :], dim=1, index=y_batch[:, :2])

    # initialize backpointers_rel - which contains the value 0 to indicate the backpointer points to the same u index,
    # 1 to indicate it points to the u-1 index, and 2 to indicate it points to the u-2 index
    backpointers_rel = -99 * torch.ones((B, T_max, U_max), dtype=torch.int8, device=viterbi_device)

    # Make a letter_repetition_mask the same shape as y_batch.
    # The letter_repetition_mask will have 'True' where the token (including blanks) is the same
    # as the token two places before it in the ground truth (and 'False' everywhere else).
    # We will use letter_repetition_mask to determine whether the Viterbi algorithm needs to look two tokens back or
    # three tokens back.
    y_shifted_left = torch.roll(y_batch, shifts=2, dims=1)
    letter_repetition_mask = y_batch - y_shifted_left
    letter_repetition_mask[:, :2] = 1  # make sure we don't apply the mask to the first 2 tokens
    letter_repetition_mask = letter_repetition_mask == 0

    for t in range(1, T_max):

        # e_current is a tensor of shape (B, U_max) of the log probs of every possible token at the current timestep
        e_current = torch.gather(input=log_probs_padded[:, t, :], dim=1, index=y_batch)

        # apply a mask to e_current to cope with the fact that we do not keep the whole v_matrix and continue
        # calculating viterbi probabilities during some 'padding' timesteps
        t_exceeded_T_batch = t >= T_batch

        U_can_be_final = torch.logical_or(
            torch.arange(0, U_max, device=viterbi_device).unsqueeze(0) == (U_batch.unsqueeze(1) - 0),
            torch.arange(0, U_max, device=viterbi_device).unsqueeze(0) == (U_batch.unsqueeze(1) - 1),
        )

        mask = torch.logical_not(torch.logical_and(t_exceeded_T_batch.unsqueeze(1), U_can_be_final,)).long()

        e_current = e_current * mask

        # v_prev_shifted is a tensor of shape (B, U_max) of the viterbi probabilities 1 timestep back and 1 token position back
        v_prev_shifted = torch.roll(v_prev, shifts=1, dims=1)
        # by doing a roll shift of size 1, we have brought the viterbi probability in the final token position to the
        # first token position - let's overcome this by 'zeroing out' the probabilities in the first token position
        v_prev_shifted[:, 0] = V_NEGATIVE_NUM

        # v_prev_shifted2 is a tensor of shape (B, U_max) of the viterbi probabilities 1 timestep back and 2 token positions back
        v_prev_shifted2 = torch.roll(v_prev, shifts=2, dims=1)
        v_prev_shifted2[:, :2] = V_NEGATIVE_NUM  # zero out as we did for v_prev_shifted
        # use our letter_repetition_mask to remove the connections between 2 blanks (so we don't skip over a letter)
        # and to remove the connections between 2 consecutive letters (so we don't skip over a blank)
        v_prev_shifted2.masked_fill_(letter_repetition_mask, V_NEGATIVE_NUM)

        # we need this v_prev_dup tensor so we can calculate the viterbi probability of every possible
        # token position simultaneously
        v_prev_dup = torch.cat(
            (v_prev.unsqueeze(2), v_prev_shifted.unsqueeze(2), v_prev_shifted2.unsqueeze(2),), dim=2,
        )

        # candidates_v_current are our candidate viterbi probabilities for every token position, from which
        # we will pick the max and record the argmax
        candidates_v_current = v_prev_dup + e_current.unsqueeze(2)
        # we save the results directly into v_prev instead of v_current, so that the variable v_prev will be ready for
        # the next iteration of the for-loop
        v_prev, bp_relative = torch.max(candidates_v_current, dim=2)

        backpointers_rel[:, t, :] = bp_relative

    # trace backpointers
    alignments_batch = []
    for b in range(B):
        T_b = int(T_batch[b])
        U_b = int(U_batch[b])

        if U_b == 1:  # i.e. we put only a blank token in the reference text because the reference text is empty
            current_u = 0  # set initial u to 0 and let the rest of the code block run as usual
        else:
            current_u = int(torch.argmax(v_prev[b, U_b - 2 : U_b])) + U_b - 2
        alignment_b = [current_u]
        for t in range(T_max - 1, 0, -1):
            current_u = current_u - int(backpointers_rel[b, t, current_u])
            alignment_b.insert(0, current_u)
        alignment_b = alignment_b[:T_b]
        alignments_batch.append(alignment_b)

    return alignments_batch
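For context, below is a minimal usage sketch (not part of the commit) showing how viterbi_decoding might be called. The toy shapes, the choice of blank ID as the last class of the log-probs, and the interleaving of blanks in y_batch are assumptions inferred from the docstring above; note that T_batch and U_batch are passed here as 1-D tensors of shape (B,), which is what the broadcasting inside the time loop appears to expect.

import torch

from viterbi_decoding import viterbi_decoding  # assumes this file is importable from the working directory

# Hypothetical toy inputs: 1 utterance, 5 timesteps, a 3-token vocabulary plus blank => V = 4 classes.
B, T_max, V = 1, 5, 4
blank_id = V - 1  # assumption: blank is the last class of the log-probs

# Log-probabilities over the V classes at every timestep.
log_probs_batch = torch.log_softmax(torch.randn(B, T_max, V), dim=-1)

# Reference token IDs [0, 2] with blanks interleaved: [blank, 0, blank, 2, blank].
tokens = [0, 2]
y = [blank_id]
for tok in tokens:
    y += [tok, blank_id]
y_batch = torch.tensor([y], dtype=torch.long)  # shape (B, U_max), U_max = 2 * len(tokens) + 1

T_batch = torch.tensor([T_max])   # true number of timesteps per utterance
U_batch = torch.tensor([len(y)])  # true length of each reference sequence

alignments_batch = viterbi_decoding(log_probs_batch, y_batch, T_batch, U_batch, torch.device("cpu"))
print(alignments_batch)  # e.g. [[0, 1, 2, 3, 4]] - one u index (position in y_batch) per timestep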