
Wav2Vec2-2-Bart-Large-Tedlium

This model is a sequence-to-sequence (seq2seq) model trained on the TEDLIUM corpus (release 3).

It combines a speech encoder with a text decoder to perform automatic speech recognition. The encoder weights are initialised with the Wav2Vec2 LV-60k checkpoint from @facebook. The decoder weights are initialised with the Bart large checkpoint from @facebook.
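
The following is a minimal sketch of how a speech encoder and a text decoder are combined, assuming a recent `transformers` release. It builds a randomly initialised toy model with tiny, hypothetical layer sizes (not the real Wav2Vec2 LV-60k or BART-large dimensions). `SpeechEncoderDecoderConfig.from_encoder_decoder_configs` marks the BART decoder as causal and adds cross-attention over the Wav2Vec2 encoder outputs. To start from pretrained weights, as this model does, the library also provides `SpeechEncoderDecoderModel.from_encoder_decoder_pretrained`. This card does not state exactly how the author assembled the model.

```python
import torch
from transformers import (
    BartConfig,
    BartForCausalLM,
    SpeechEncoderDecoderConfig,
    SpeechEncoderDecoderModel,
    Wav2Vec2Config,
    Wav2Vec2Model,
)

# Tiny, hypothetical sizes so the sketch runs quickly on CPU.
encoder_config = Wav2Vec2Config(
    hidden_size=32,
    num_hidden_layers=2,
    num_attention_heads=2,
    intermediate_size=37,
    conv_dim=(32, 32, 32),
    conv_stride=(4, 4, 4),
    conv_kernel=(8, 8, 8),
    num_conv_pos_embeddings=16,
    num_conv_pos_embedding_groups=2,
)
decoder_config = BartConfig(
    vocab_size=99,
    d_model=32,
    decoder_layers=2,
    decoder_attention_heads=2,
    decoder_ffn_dim=37,
    max_position_embeddings=64,
)

# Sets is_decoder=True and add_cross_attention=True on the decoder config.
config = SpeechEncoderDecoderConfig.from_encoder_decoder_configs(encoder_config, decoder_config)
model = SpeechEncoderDecoderModel(config=config).eval()

# One second of random "audio" at 16 kHz and a short decoder prefix.
input_values = torch.randn(1, 16_000)
decoder_input_ids = torch.tensor([[0, 5, 6]])
with torch.no_grad():
    logits = model(input_values=input_values, decoder_input_ids=decoder_input_ids).logits

# Encoder and decoder classes, and logits of shape (batch, decoder length, vocab size).
print(type(model.encoder).__name__, type(model.decoder).__name__, tuple(logits.shape))
```

The encoder turns raw waveform samples into a sequence of hidden states. The decoder then predicts text tokens one at a time while cross-attending to those states.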

When using the model, make sure that your speech input is sampled at 16 kHz.
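
If your recordings use a different rate (for example 44.1 kHz), resample them before passing them to the processor. The following is a minimal sketch, assuming mono audio held in a NumPy array and using SciPy's polyphase resampler. The helper name `to_16khz` is hypothetical and not part of this model's code.

```python
import numpy as np
from scipy.signal import resample_poly

TARGET_SR = 16_000  # the model expects 16 kHz input


def to_16khz(waveform, orig_sr):
    """Polyphase-resample a mono waveform to 16 kHz (hypothetical helper)."""
    waveform = np.asarray(waveform, dtype=np.float32)
    if orig_sr == TARGET_SR:
        return waveform
    # resample_poly reduces up/down by their GCD (16000/44100 -> 160/441).
    return resample_poly(waveform, up=TARGET_SR, down=orig_sr).astype(np.float32)


# One second of a 440 Hz tone at 44.1 kHz becomes 16 000 samples.
t = np.arange(44_100) / 44_100
audio_16k = to_16khz(np.sin(2 * np.pi * 440 * t), 44_100)
print(audio_16k.shape)
```

When the audio comes from a `datasets` dataset, `ds.cast_column("audio", Audio(sampling_rate=16_000))` achieves the same result by resampling on the fly.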

The model achieves a word error rate (WER) of 9.0% on the dev set and 6.4% on the test set. Training logs document the training and evaluation progress over 50k steps of fine-tuning.
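
WER counts the word substitutions (S), deletions (D) and insertions (I) needed to turn the transcription into the reference. It then divides that count by the number of reference words (N): WER = (S + D + I) / N. For example, a test WER of 6.4% means roughly 6.4 word errors per 100 reference words. The following sketch computes WER for one utterance with a word-level Levenshtein distance. It is an illustration only, not the `jiwer` implementation used in the evaluation code, and it applies no text normalisation.

```python
def word_error_rate(reference: str, hypothesis: str) -> float:
    """WER = (S + D + I) / N via word-level edit distance (illustrative sketch)."""
    ref, hyp = reference.split(), hypothesis.split()
    if not ref:
        raise ValueError("reference must contain at least one word")
    # prev[j] holds the edit distance between the first i-1 reference words
    # and the first j hypothesis words.
    prev = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, start=1):
        curr = [i] + [0] * len(hyp)
        for j, h in enumerate(hyp, start=1):
            curr[j] = min(
                prev[j] + 1,             # deletion
                curr[j - 1] + 1,         # insertion
                prev[j - 1] + (r != h),  # substitution (cost 0 on a match)
            )
        prev = curr
    return prev[-1] / len(ref)


# One substitution (b -> x) and one deletion (d) over four reference words.
print(word_error_rate("a b c d", "a x c"))
```

For a whole test set, corpus-level WER is normally computed by summing errors and reference words over all utterances before dividing, rather than by averaging per-utterance WERs. WER can also exceed 100% when the transcription contains many insertions.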

Usage

To transcribe audio files, the model can be used on its own (without an external language model) as follows:

```python
from datasets import load_dataset
from transformers import AutoProcessor, SpeechEncoderDecoderModel

# load model and processor
processor = AutoProcessor.from_pretrained("sanchit-gandhi/wav2vec2-2-bart-large-tedlium")
model = SpeechEncoderDecoderModel.from_pretrained("sanchit-gandhi/wav2vec2-2-bart-large-tedlium")

# load dummy dataset
ds = load_dataset("sanchit-gandhi/tedlium_dummy", split="validation")
sample = ds[0]["audio"]

# process audio inputs (batch size 1); passing the sampling rate lets the
# processor check that the audio is at 16 kHz
input_values = processor(
    sample["array"], sampling_rate=sample["sampling_rate"], return_tensors="pt", padding="longest"
).input_values

# run inference (greedy search)
generated = model.generate(input_values)

# decode
decoded = processor.batch_decode(generated, skip_special_tokens=True)
print("Target: ", ds[0]["text"])
print("Transcription: ", decoded[0])
```
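
Before the waveform reaches the encoder, the processor prepares it. The following is a simplified NumPy sketch of what a Wav2Vec2-style feature extractor typically does when it is configured with `do_normalize=True` and `return_attention_mask=True`: each waveform is normalised to zero mean and unit variance, shorter ones are right-padded to the longest in the batch (`padding="longest"`), and an attention mask marks the real samples. These settings are an assumption about this checkpoint; check its `preprocessor_config.json`. The helper `pad_and_normalise` is hypothetical, not the library's code.

```python
import numpy as np


def pad_and_normalise(batch, padding_value=0.0, eps=1e-7):
    """Hypothetical sketch: per-utterance zero-mean/unit-variance
    normalisation, right-padding to the longest input, attention mask."""
    max_len = max(len(x) for x in batch)
    values = np.full((len(batch), max_len), padding_value, dtype=np.float32)
    mask = np.zeros((len(batch), max_len), dtype=np.int32)
    for i, x in enumerate(batch):
        x = np.asarray(x, dtype=np.float32)
        # statistics use only the real samples, never the padding
        values[i, : len(x)] = (x - x.mean()) / np.sqrt(x.var() + eps)
        mask[i, : len(x)] = 1
    return values, mask


values, mask = pad_and_normalise([np.array([1.0, 2.0, 3.0]), np.array([4.0, 4.0])])
print(values.shape, mask.tolist())
```

The attention mask matters once batches contain utterances of different lengths, because without it the model would treat the padding as audio.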

Evaluation

This code snippet shows how to evaluate Wav2Vec2-2-Bart-Large-Tedlium on the TEDLIUM test data.

```python
from datasets import load_dataset
from jiwer import wer
from transformers import AutoProcessor, SpeechEncoderDecoderModel
import torch

tedlium_eval = load_dataset("LIUM/tedlium", "release3", split="test")

def filter_ds(text):
    return text != "ignore_time_segment_in_scoring"

# remove samples ignored from scoring (filter, not map: filter_ds returns a bool)
tedlium_eval = tedlium_eval.filter(filter_ds, input_columns=["text"])

model = SpeechEncoderDecoderModel.from_pretrained("sanchit-gandhi/wav2vec2-2-bart-large-tedlium").to("cuda")
processor = AutoProcessor.from_pretrained("sanchit-gandhi/wav2vec2-2-bart-large-tedlium")

# beam-search decoding settings
gen_kwargs = {
    "max_length": 200,
    "num_beams": 5,
    "length_penalty": 1.2,
}

def map_to_pred(batch):
    # with batched=True, batch["audio"] is a list of audio dicts, one per example
    arrays = [audio["array"] for audio in batch["audio"]]
    # TEDLIUM audio is sampled at 16 kHz
    input_values = processor(arrays, sampling_rate=16_000, return_tensors="pt", padding="longest").input_values
    with torch.no_grad():
        generated = model.generate(input_values.to("cuda"), **gen_kwargs)
    # one transcription per example in the batch
    batch["transcription"] = processor.batch_decode(generated, skip_special_tokens=True)
    return batch

# batch_size=1 avoids padding; the dataset has no "speech" column, so drop "audio"
result = tedlium_eval.map(map_to_pred, batched=True, batch_size=1, remove_columns=["audio"])
print("WER:", wer(result["text"], result["transcription"]))
```
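
The `length_penalty` in `gen_kwargs` affects which finished beam-search hypothesis is returned. In `transformers`, the sum of a hypothesis's token log-probabilities is divided by its length raised to `length_penalty`, so values above 0 reduce the bias towards short outputs. The following sketch shows only that scoring rule. The exact way the library counts length can differ between versions, and the two hypotheses and their log-probabilities are made up for illustration.

```python
def beam_score(sum_logprobs: float, length: int, length_penalty: float = 1.2) -> float:
    """Length-normalised score: sum of log-probs / length ** length_penalty."""
    return sum_logprobs / (length ** length_penalty)


# Two made-up finished hypotheses: a short one and a longer, less likely one.
short_score = beam_score(-4.0, 5)
long_score = beam_score(-6.0, 10)
print(short_score, long_score)
```

With `length_penalty=1.2`, the longer hypothesis scores higher (about -0.38 against -0.58). With `length_penalty=0.0`, the scores reduce to the raw sums (-4.0 against -6.0) and the shorter one wins. This behaviour can help avoid truncated transcriptions of long TEDLIUM segments.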