jlonsako committed
Commit
394d8c1
1 Parent(s): 9437cb7

Added padding to batch_decode to handle differing audio sample sizes

Files changed (1)
app.py +2 -2
app.py CHANGED
@@ -108,7 +108,7 @@ def Transcribe(file, batch_size):
         # If the batch is full, process it
         if len(batch) == batch_size:
             # Concatenate all segments in the batch along the time axis
-            input_values = processor(batch, sampling_rate=16_000, return_tensors="pt")
+            input_values = processor(batch, sampling_rate=16_000, return_tensors="pt", padding=True)
             input_values = input_values.to(device)
             with torch.no_grad():
                 logits = model(**input_values).logits
@@ -141,7 +141,7 @@ def Transcribe(file, batch_size):
 
     if batch:
         # Concatenate all segments in the batch along the time axis
-        input_values = processor(batch, sampling_rate=16_000, return_tensors="pt")
+        input_values = processor(batch, sampling_rate=16_000, return_tensors="pt", padding=True)
         input_values = input_values.to(device)
         with torch.no_grad():
             logits = model(**input_values).logits
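For context, here is a minimal runnable sketch of the padded batching this change enables. The checkpoint name, device handling, and decode step are assumptions for illustration; the diff does not show how app.py actually loads its model and processor.

import torch
from transformers import AutoModelForCTC, AutoProcessor

# Hypothetical checkpoint; app.py's actual model is not shown in this diff.
model_id = "facebook/wav2vec2-base-960h"
processor = AutoProcessor.from_pretrained(model_id)
model = AutoModelForCTC.from_pretrained(model_id)
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

# Two audio segments of unequal length (1.0 s and 1.5 s at 16 kHz).
batch = [torch.randn(16_000).numpy(), torch.randn(24_000).numpy()]

# Without padding=True the feature extractor raises an error when asked to
# stack arrays of different lengths into one tensor; with it, the shorter
# segment is zero-padded to the length of the longest segment in the batch.
input_values = processor(batch, sampling_rate=16_000, return_tensors="pt", padding=True)
input_values = input_values.to(device)
with torch.no_grad():
    logits = model(**input_values).logits

# Greedy CTC decode of both segments at once.
transcripts = processor.batch_decode(torch.argmax(logits, dim=-1))
print(transcripts)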