import os

# whisper-jax is not published on PyPI, so install it from GitHub at startup.
os.system("pip install git+https://github.com/sanchit-gandhi/whisper-jax.git")

from flask import Flask, jsonify, request
import requests
import time

from whisper_jax import FlaxWhisperPipline
import jax.numpy as jnp


# Load the Flax Whisper pipeline once at startup: bfloat16 halves the compute
# precision and batch_size=16 lets long audio be chunked and transcribed in
# parallel. (FlaxWhisperPipline is the class name as spelled by whisper-jax.)
pipe = FlaxWhisperPipline("openai/whisper-small", dtype=jnp.bfloat16, batch_size=16)
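
# Note: JAX traces and compiles the model on the first call, so the first
# /run request is much slower than later ones. A minimal warm-up sketch
# (hypothetical: assumes a short clip named warmup.mp3 ships with the app):
#
#   if os.path.exists("warmup.mp3"):
#       pipe("warmup.mp3", task="transcribe")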

app = Flask(__name__)
# Flask itself does not read a TIMEOUT key; real request timeouts must be set
# on the WSGI server (e.g. gunicorn's --timeout flag).
app.config['TIMEOUT'] = 60 * 10  # 10 minutes

@app.route("/")
def indexApi():
    # Health check.
    return jsonify({"output": "okay"})

@app.route("/run", methods=['POST'])
def runApi():
    start_time = time.time()

    audio_url = request.form.get("audio_url")
    if not audio_url:
        return jsonify({"result": "Missing required form field: audio_url"}), 400

    response = requests.get(audio_url)

    if response.status_code == requests.codes.ok:
        with open("audio.mp3", "wb") as f:
            f.write(response.content)
    else:
        return jsonify({
            "result": f"Unable to save file, status code: {response.status_code}",
        }), 400

    audio = "audio.mp3"
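    # response.content buffers the whole file in memory; for long recordings a
    # streaming download is gentler. A sketch using the same requests API:
    #
    #   with requests.get(audio_url, stream=True) as r:
    #       with open("audio.mp3", "wb") as f:
    #           for chunk in r.iter_content(chunk_size=8192):
    #               f.write(chunk)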

    # Transcribe; whisper-jax returns a dict whose "text" field holds the
    # decoded transcript.
    prediction = pipe(audio, task="transcribe")["text"]
    end_time = time.time()
    total_time = end_time - start_time

    return jsonify({
        "audio_url": audio_url,
        "result": prediction,
        "exec_time_sec": total_time
    })

if __name__ == "__main__":
    app.run(host="0.0.0.0", port=7860)
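
# Example client call (a minimal sketch; assumes the server is reachable on
# localhost:7860 and the URL below stands in for a real downloadable MP3):
#
#   curl -X POST http://localhost:7860/run \
#        -F "audio_url=https://example.com/sample.mp3"
#
# On success the response is JSON:
#   {"audio_url": "...", "result": "<transcript>", "exec_time_sec": <seconds>}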