nithinraok commited on
Commit
ea18850
1 Parent(s): 39684bb

Update nemo_align.py

Browse files
Files changed (1) hide show
  1. nemo_align.py +10 -3
nemo_align.py CHANGED
@@ -440,8 +440,15 @@ def align_tdt_to_ctc_timestamps(tdt_txt, model, audio_filepath):
440
  model.change_decoding_strategy(decoder_type="ctc")
441
  else:
442
  raise ValueError("Currently supporting hybrid models")
443
-
444
- with torch.cuda.amp.autocast(enabled=True, dtype=torch.bfloat16):
 
 
 
 
 
 
 
445
  with torch.inference_mode():
446
  hypotheses = model.transcribe([audio_filepath], return_hypotheses=True, batch_size=1)
447
 
@@ -498,7 +505,7 @@ def align_tdt_to_ctc_timestamps(tdt_txt, model, audio_filepath):
498
  model.preprocessor.featurizer.hop_length * model_downsample_factor / model.cfg.preprocessor.sample_rate
499
  )
500
 
501
- alignments_batch = viterbi_decoding(log_probs_batch, y_batch, T_batch, U_batch, torch.device('cuda'))
502
 
503
 
504
  utt_obj = add_t_start_end_to_utt_obj(utt_obj, alignments_batch[0], output_timestep_duration)
 
440
  model.change_decoding_strategy(decoder_type="ctc")
441
  else:
442
  raise ValueError("Currently supporting hybrid models")
443
+
444
+ if torch.cuda.is_available():
445
+ enable = True
446
+ viterbi_device = torch.device('cuda')
447
+ else:
448
+ enable = False
449
+ viterbi_device = torch.device('cpu')
450
+
451
+ with torch.cuda.amp.autocast(enabled=enable, dtype=torch.bfloat16):
452
  with torch.inference_mode():
453
  hypotheses = model.transcribe([audio_filepath], return_hypotheses=True, batch_size=1)
454
 
 
505
  model.preprocessor.featurizer.hop_length * model_downsample_factor / model.cfg.preprocessor.sample_rate
506
  )
507
 
508
+ alignments_batch = viterbi_decoding(log_probs_batch, y_batch, T_batch, U_batch, viterbi_device)
509
 
510
 
511
  utt_obj = add_t_start_end_to_utt_obj(utt_obj, alignments_batch[0], output_timestep_duration)