Sabbah13's picture
Update app.py
361b6e2 verified
raw
history blame
No virus
4.07 kB
import base64
import json
import os
import tempfile
import uuid

import requests
import streamlit as st
import torch
import whisperx
def convert_segments_to_text(segments):
    """Render speaker-labelled transcription segments as text, one line each.

    Every line has the form ``"<speaker> (<start> : <end>) : <text>"``.

    Args:
        segments: Either an iterable of segment dicts with ``'start'``,
            ``'end'`` and ``'text'`` keys (and normally ``'speaker'``), or a
            whole result dict whose ``'segments'`` key holds such an iterable.

    Returns:
        The formatted lines joined with ``'\\n'``; ``''`` for no segments.

    Raises:
        KeyError: if a dict without a ``'segments'`` key is passed, or a
            segment lacks ``'start'``, ``'end'`` or ``'text'``.
    """
    # Accept the full result mapping too: iterating the mapping itself would
    # yield its string keys, and indexing those by 'speaker' raises TypeError.
    if isinstance(segments, dict):
        segments = segments['segments']

    lines = []
    for segment in segments:
        # A segment can carry no 'speaker' key (e.g. when no diarization turn
        # was matched to it); label it instead of failing with KeyError.
        speaker = segment.get('speaker', 'UNKNOWN')
        start = segment['start']
        end = segment['end']
        text = segment['text']
        lines.append(f'{speaker} ({start} : {end}) : {text}')
    return '\n'.join(lines)
st.title('Audio Transcription App')
st.sidebar.title("Settings")

# Sidebar inputs.
# Default to CUDA + float16 only when a GPU is actually present; on a
# CPU-only host the previous fixed defaults ('cuda', 'float16') made the
# default configuration unusable (float16 compute is typically unsupported
# on CPU). Users can still pick any option explicitly.
_has_cuda = torch.cuda.is_available()
device = st.sidebar.selectbox("Device", ["cpu", "cuda"], index=1 if _has_cuda else 0)
batch_size = st.sidebar.number_input("Batch Size", min_value=1, value=16)
compute_type = st.sidebar.selectbox(
    "Compute Type", ["float16", "int8"], index=0 if _has_cuda else 1
)

# Hugging Face access token, read from the Streamlit secrets store.
ACCESS_TOKEN = st.secrets["HF_TOKEN"]

# Audio upload widget ("Upload an audio file").
uploaded_file = st.file_uploader("Загрузите аудиофайл", type=["mp4", "wav", "m4a"])
# (connect, read) timeouts in seconds for GigaChat HTTP calls; without a
# timeout a stalled server would hang the session forever.
_HTTP_TIMEOUT = (10, 300)


def _require_env(name):
    """Return environment variable *name*, raising a clear error if unset.

    ``os.getenv`` returns ``None`` for a missing variable, which otherwise
    surfaces far from its cause (``None + str``, ``int(None)``, or requests
    silently omitting a ``None``-valued form field or header).
    """
    value = os.getenv(name)
    if value is None:
        raise RuntimeError(f'Environment variable {name} is not set')
    return value


def _giga_verify():
    """Return the ``verify`` argument used for requests to GigaChat.

    If GIGA_CA_BUNDLE names a CA bundle file, server certificates are checked
    against it; otherwise verification stays disabled as before.
    """
    # SECURITY: verify=False accepts any certificate, so credentials and the
    # transcript can be intercepted in transit. Prefer setting GIGA_CA_BUNDLE.
    return os.getenv('GIGA_CA_BUNDLE') or False


def _transcribe_and_diarize(audio_path, device, batch_size, compute_type, hf_token):
    """Transcribe Russian speech in *audio_path* with WhisperX and assign speakers.

    Returns whatever ``whisperx.assign_word_speakers`` returns (the code below
    reads its ``'segments'`` entry).
    """
    model = whisperx.load_model(
        _require_env('WHISPER_MODEL_SIZE'), device, compute_type=compute_type
    )
    audio = whisperx.load_audio(audio_path)
    result = model.transcribe(audio, batch_size=batch_size, language="ru")
    print('Transcribed, now diarizing')
    diarize_model = whisperx.DiarizationPipeline(use_auth_token=hf_token, device=device)
    diarize_segments = diarize_model(audio)
    return whisperx.assign_word_speakers(diarize_segments, result)


def _get_giga_access_token():
    """Exchange the GigaChat client credentials for an access token.

    Raises:
        requests.HTTPError: if the auth endpoint returns an error status.
    """
    username = st.secrets["GIGA_USERNAME"]
    password = st.secrets["GIGA_SECRET"]
    # HTTP Basic credentials: base64("username:password").
    auth_base64 = base64.b64encode(f'{username}:{password}'.encode('utf-8')).decode('utf-8')
    headers = {
        'Authorization': f'Basic {auth_base64}',
        # Request id; fall back to a fresh UUID4 when GIGA_rquid is unset
        # instead of sending no RqUID header at all.
        'RqUID': os.getenv('GIGA_rquid') or str(uuid.uuid4()),
        'Content-Type': 'application/x-www-form-urlencoded',
        'Accept': 'application/json',
    }
    data = {'scope': _require_env('GIGA_SCOPE')}
    response = requests.post(
        _require_env('GIGA_AUTH_URL'),
        headers=headers,
        data=data,
        verify=_giga_verify(),
        timeout=_HTTP_TIMEOUT,
    )
    # Fail with the HTTP status rather than an opaque KeyError below.
    response.raise_for_status()
    return response.json()['access_token']


def _summarize_with_giga(transcript, access_token):
    """Send GIGA_BASE_PROMPT + *transcript* to the completion endpoint.

    Returns the content of the first choice's message.

    Raises:
        requests.HTTPError: if the completion endpoint returns an error status.
    """
    payload = json.dumps({
        "model": _require_env('GIGA_MODEL'),
        "messages": [
            {
                "role": "user",
                "content": _require_env('GIGA_BASE_PROMPT') + transcript,
            }
        ],
        "stream": False,
        "max_tokens": int(_require_env('GIGA_MAX_TOKENS')),
    })
    headers = {
        'Content-Type': 'application/json',
        'Accept': 'application/json',
        'Authorization': 'Bearer ' + access_token,
    }
    response = requests.post(
        _require_env('GIGA_COMPLETION_URL'),
        headers=headers,
        data=payload,
        verify=_giga_verify(),
        timeout=_HTTP_TIMEOUT,
    )
    response.raise_for_status()
    return response.json()['choices'][0]['message']['content']


if uploaded_file is not None:
    st.audio(uploaded_file)

    # Save the upload to a unique temporary file, keeping the original
    # extension as before. A fixed name in the working directory would be
    # shared by concurrent sessions and was never deleted.
    suffix = os.path.splitext(uploaded_file.name)[1]
    with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
        tmp.write(uploaded_file.getbuffer())
        temp_file_path = tmp.name

    try:
        with st.spinner('Транскрибируем...'):  # "Transcribing..."
            result_diar = _transcribe_and_diarize(
                temp_file_path, device, batch_size, compute_type, ACCESS_TOKEN
            )
    finally:
        # The audio file is not needed after transcription.
        os.remove(temp_file_path)

    st.text(result_diar)
    st.write("Результат транскрибации:")  # "Transcription result:"
    # The speaker-assigned result is a mapping; the segment list is under
    # 'segments' (iterating the mapping itself would yield its keys).
    transcript = convert_segments_to_text(result_diar['segments'])
    st.text(transcript)

    with st.spinner('Резюмируем...'):  # "Summarizing..."
        access_token = _get_giga_access_token()
        print('Got access token')
        answer_from_llm = _summarize_with_giga(transcript, access_token)

    st.write("Результат резюмирования:")  # "Summary result:"
    st.text(answer_from_llm)