File size: 2,900 Bytes
0aa495b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 |
{
"cells": [
{
"cell_type": "markdown",
"id": "5b9e2c1a-7f3d-4c8e-a6b1-0d4f2e9c7a13",
"metadata": {},
"source": [
"# faster-whisper transcription timing\n",
"\n",
"Transcribes the first sample of the `hf-internal-testing/librispeech_asr_dummy` validation split with faster-whisper and reports the decoding time. Dataset and model loading live in their own cell so the timed cell can be re-run on its own."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "dea2016d-48c1-469a-a3ea-7856850b7725",
"metadata": {},
"outputs": [],
"source": [
"from time import perf_counter\n",
"\n",
"from datasets import load_dataset\n",
"from faster_whisper import WhisperModel\n",
"\n",
"MODEL_SIZE = \"large-v3\"\n",
"DEVICE = \"cuda\"\n",
"COMPUTE_TYPE = \"float16\"\n",
"BEAM_SIZE = 5\n",
"CACHE_DIR = \".\"  # dataset cache and model weights are stored next to the notebook\n",
"\n",
"# The dataset repo runs custom loading code; opting in explicitly avoids the\n",
"# FutureWarning from `datasets` (the opt-in becomes mandatory in its next major release).\n",
"ds = load_dataset(\n",
"    \"hf-internal-testing/librispeech_asr_dummy\",\n",
"    \"clean\",\n",
"    split=\"validation\",\n",
"    cache_dir=CACHE_DIR,\n",
"    trust_remote_code=True,\n",
")\n",
"model = WhisperModel(MODEL_SIZE, device=DEVICE, compute_type=COMPUTE_TYPE, download_root=CACHE_DIR)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c299671e-4f4b-485c-a36f-2a35ea258995",
"metadata": {},
"outputs": [],
"source": [
"audio_sample = ds[0][\"audio\"]\n",
"waveform = audio_sample[\"array\"]\n",
"sampling_rate = audio_sample[\"sampling_rate\"]\n",
"# faster-whisper uses a raw ndarray as-is at 16 kHz; it does not resample it.\n",
"assert sampling_rate == 16_000, f\"expected 16 kHz audio, got {sampling_rate} Hz\"\n",
"\n",
"tic = perf_counter()\n",
"segments, info = model.transcribe(waveform, beam_size=BEAM_SIZE)\n",
"# `segments` is a lazy generator: decoding only happens while it is consumed,\n",
"# so materialize it inside the timed region or the timing measures almost nothing.\n",
"segments = list(segments)\n",
"elapsed = perf_counter() - tic\n",
"\n",
"for segment in segments:\n",
"    print(\"[%.2fs -> %.2fs] %s\" % (segment.start, segment.end, segment.text))\n",
"print(f\"Transcription time: {elapsed:.2f}s\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
|