ASR-comparaison / app.py
Steveeeeeeen's picture
Update app.py
f0f7172 verified
raw
history blame
No virus
3.4 kB
import functools
import json
from pathlib import Path

from datasets import load_dataset, Dataset
from transformers import pipeline
import evaluate
import numpy as np
import gradio as gr
# Word Error Rate metric from the `evaluate` library, shared by every request.
wer_metric = evaluate.load("wer")

# Display label -> Hugging Face Hub checkpoint id for each compared ASR model.
model_name = {
    "whisper-tiny": "openai/whisper-tiny.en",
    # The checkpoint must match its label and the model link shown in the app
    # description (both say wav2vec2-large-960h).
    "wav2vec2-large-960h": "facebook/wav2vec2-large-960h",
    "distill-whisper-small": "distil-whisper/distil-small.en",
}

# Example metadata, resolved next to this file (like demo.md) instead of the
# current working directory. The UI indexes it by row and reads a "reference"
# field from each row — presumably one row per example clip; verify against
# ds_data.json.
table_data = json.loads(
    (Path(__file__).parent / "ds_data.json").read_text(encoding="utf-8")
)
@functools.lru_cache(maxsize=None)
def _get_asr_pipeline(model_id):
    """Return the speech-recognition pipeline for *model_id*, built on first use.

    Building a pipeline loads the model weights, which is far too slow to do
    on every request. The cache holds one entry per checkpoint, so it is
    bounded by the number of entries in ``model_name``.
    """
    return pipeline("automatic-speech-recognition", model=model_id)


def _letters_and_spaces(text):
    """Drop every character that is neither a letter nor whitespace."""
    return "".join(ch for ch in text if ch.isalpha() or ch.isspace())


def _to_float_mono(samples):
    """Convert a numpy audio buffer to a 1-D float32 waveform.

    Integer samples are divided by the dtype's maximum (32767 for int16), so
    16-bit PCM lands in roughly [-1, 1]. Float input is left unscaled.
    Multi-channel input is averaged down to a single channel.
    """
    samples = np.asarray(samples)
    waveform = samples.astype(np.float32)
    if np.issubdtype(samples.dtype, np.integer):
        waveform /= np.iinfo(samples.dtype).max
    if waveform.ndim > 1:
        # NOTE(review): assumes Gradio's (n_samples, n_channels) layout for
        # multi-channel audio — confirm against the installed Gradio version.
        waveform = waveform.mean(axis=1)
    return waveform


def compute_wer_table(audio, text):
    """Transcribe *audio* with every model in ``model_name`` and score it against *text*.

    Args:
        audio: ``(sample_rate, samples)`` tuple, as produced by ``gr.Audio``
            with ``type="numpy"``.
        text: Reference transcript.

    Returns:
        A list of ``[model label, transcription, WER]`` rows, one per entry of
        ``model_name``, in dict order. The transcription is shown with only
        letters and whitespace kept.

    Raises:
        gr.Error: If no audio is given, or the reference has no letters.
    """
    if audio is None:
        raise gr.Error("Please provide an input audio sample.")

    # Normalize the reference exactly like the hypotheses below (letters and
    # whitespace only, upper-cased). Otherwise punctuation such as apostrophes
    # would count as word errors.
    reference = _letters_and_spaces(text or "").upper()
    if not reference.strip():
        # WER divides by the number of reference words, so it is undefined here.
        raise gr.Error("Please provide a non-empty reference text.")

    sampling_rate, samples = audio
    waveform = _to_float_mono(samples)

    rows = []
    for label, model_id in model_name.items():
        asr = _get_asr_pipeline(model_id)
        # Pass the real sampling rate so the pipeline can resample when needed.
        # A fresh dict is built per call because the pipeline may consume keys
        # from a dict input. NOTE(review): resampling may require torchaudio
        # to be installed — confirm in the deployment environment.
        output = asr({"raw": waveform, "sampling_rate": sampling_rate})
        transcription = _letters_and_spaces(output["text"])
        wer = wer_metric.compute(
            predictions=[transcription.upper()], references=[reference]
        )
        rows.append([label, transcription, wer])
    return rows
# Gradio UI: a "Docs" tab that renders demo.md, and a "Demo" tab that wraps
# compute_wer_table in a gr.Interface with bundled example clips.
with gr.Blocks() as demo:
    with gr.Tab("Docs"):
        gr.Markdown(
            (Path(__file__).parent / "demo.md").read_text(encoding="utf-8")
        )
    with gr.Tab("Demo"):
        gr.Interface(
            fn=compute_wer_table,
            inputs=[
                # compute_wer_table expects the (sample_rate, samples) tuple
                # that the numpy type produces.
                gr.Audio(label="Input Audio", type="numpy"),
                gr.Textbox(label="Reference Text"),
            ],
            outputs=gr.Dataframe(
                headers=["Model", "Transcription", "WER"], label="WER Results"
            ),
            # Up to 100 examples (the description mentions the first 100
            # samples). The cap is also the number of rows actually loaded, so
            # a shorter ds_data.json no longer raises IndexError.
            examples=[
                [f"assets/output_audio_{i}.wav", row["reference"]]
                for i, row in enumerate(table_data[:100])
            ],
            title="ASR Model Evaluation",
            description=(
                "This application allows you to evaluate the performance of various Automatic Speech Recognition (ASR) models on "
                "a given audio sample. Simply provide an audio file and the corresponding reference text, and the app will compute "
                "the Word Error Rate (WER) for each model. The results will be presented in a table that includes the model name, "
                "the transcribed text, and the calculated WER. "
                "\n\n### Table of Results\n"
                "The table below shows the transcriptions generated by different ASR models, along with their corresponding WER scores. "
                "Lower WER scores indicate better performance."
                "\n\n| Model | WER |\n"
                "|--------------------------|--------------------------|\n"
                "| [whisper-tiny](https://huggingface.co/openai/whisper-tiny.en) | 0.06052 |\n"
                "| [wav2vec2-large-960h](https://huggingface.co/facebook/wav2vec2-large-960h) | 0.02201 |\n"
                "| [distill-whisper-small](https://huggingface.co/distil-whisper/distil-small.en)| 0.03959 |\n"
                "\n\n### Data Source\n"
                "The data used in this demo is a subset of the [LibriSpeech](https://huggingface.co/datasets/openslr/librispeech_asr) dataset which contains the first 100 audio samples and their corresponding reference texts in the validation set."
            ),
        )

# Launch only when run as a script (e.g. `python app.py`), not on import.
if __name__ == "__main__":
    demo.launch()