Sabbah13's picture
Update app.py
0bf6fb0 verified
raw
history blame
No virus
4.33 kB
import base64
import json
import os
import tempfile

import requests
import streamlit as st
import torch
import whisperx

from utils import convert_segments_object_to_text, get_completion_from_gigachat
# Page header and sidebar configuration for the transcription pipeline.
st.title('Audio Transcription App')
st.sidebar.title("Settings")

# Sidebar inputs.
# Pre-select CUDA only when torch can actually see a GPU; a hard-coded "cuda"
# default makes the pipeline fail out of the box on CPU-only hosts.
device = st.sidebar.selectbox(
    "Device", ["cpu", "cuda"], index=1 if torch.cuda.is_available() else 0
)
batch_size = st.sidebar.number_input("Batch Size", min_value=1, value=16)
# float16 is a GPU compute type; WhisperX recommends int8 when running on CPU,
# so the default follows the selected device (the user can still override it).
compute_type = st.sidebar.selectbox(
    "Compute Type", ["float16", "int8"], index=0 if device == "cuda" else 1
)

# GigaChat prompts are seeded from the environment. Default to "" so that the
# later `prompt + transcript` concatenation never receives None when a
# variable is unset. The variable name 'GIGA_PROCCESS_PROMPT' is read exactly
# as spelled; renaming it would require changing the deployment config too.
initial_giga_base_prompt = os.getenv('GIGA_BASE_PROMPT', '')
initial_giga_processing_prompt = os.getenv('GIGA_PROCCESS_PROMPT', '')
# Prompt prepended to the transcript when asking GigaChat for a summary.
giga_base_prompt = st.sidebar.text_area("Промпт ГигаЧата для резюмирования", value=initial_giga_base_prompt)
# Token cap for the summary request.
giga_max_tokens = st.sidebar.number_input("Максимальное количество токенов при резюмировании", min_value=1, value=1024)
# Optional extra pass in which GigaChat rewrites the transcript before summarizing.
enable_summarization = st.sidebar.checkbox("Добавить обработку транскрибации", value=False)
giga_processing_prompt = st.sidebar.text_area("Промпт ГигаЧата для обработки транскрибации", value=initial_giga_processing_prompt)

# Hugging Face access token (the same secret the diarization step authenticates with).
ACCESS_TOKEN = st.secrets["HF_TOKEN"]
# Main flow: upload audio -> WhisperX transcription, alignment and diarization
# -> obtain a GigaChat token -> optional transcript processing -> summary.
uploaded_file = st.file_uploader("Загрузите аудиофайл", type=["mp4", "wav", "m4a"])

if uploaded_file is not None:
    st.audio(uploaded_file)

    # whisperx.load_audio takes a path, so the upload is written to disk first.
    # A uniquely named temp file (instead of a fixed "temp_file.<ext>" in the
    # working directory) keeps concurrent sessions from overwriting each
    # other's audio. The original extension is kept as the suffix.
    file_extension = os.path.splitext(uploaded_file.name)[1]
    with tempfile.NamedTemporaryFile(suffix=file_extension, delete=False) as temp_file:
        temp_file.write(uploaded_file.getbuffer())
        temp_file_path = temp_file.name

    try:
        with st.spinner('Транскрибируем...'):
            # Load the Whisper model (size from WHISPER_MODEL_SIZE) and transcribe as Russian.
            model = whisperx.load_model(os.getenv('WHISPER_MODEL_SIZE'), device, compute_type=compute_type)
            audio = whisperx.load_audio(temp_file_path)
            result = model.transcribe(audio, batch_size=batch_size, language="ru")
            print('Transcribed, now aligning')

            # Align the segments using an alignment model for the transcription's language.
            model_a, metadata = whisperx.load_align_model(language_code=result["language"], device=device)
            result = whisperx.align(result["segments"], model_a, metadata, audio, device, return_char_alignments=False)
            print('Aligned, now diarizing')

            # Run speaker diarization and attach speaker labels to the aligned result.
            diarize_model = whisperx.DiarizationPipeline(use_auth_token=ACCESS_TOKEN, device=device)
            diarize_segments = diarize_model(audio)
            result_diar = whisperx.assign_word_speakers(diarize_segments, result)
    finally:
        # Always clean up the temp file, even if a model step raised.
        os.remove(temp_file_path)

    st.write("Результат транскрибации:")
    transcript = convert_segments_object_to_text(result_diar)
    st.text(transcript)

    # Request a GigaChat access token using HTTP Basic credentials from secrets.
    username = st.secrets["GIGA_USERNAME"]
    password = st.secrets["GIGA_SECRET"]
    auth_base64 = base64.b64encode(f'{username}:{password}'.encode('utf-8')).decode('utf-8')

    url = os.getenv('GIGA_AUTH_URL')
    headers = {
        'Authorization': f'Basic {auth_base64}',
        'RqUID': os.getenv('GIGA_rquid'),
        'Content-Type': 'application/x-www-form-urlencoded',
        'Accept': 'application/json',
    }
    data = {'scope': os.getenv('GIGA_SCOPE')}
    # NOTE(review): verify=False disables TLS certificate checks on a request
    # that carries credentials. Presumably the auth endpoint's CA is missing
    # from the default bundle; prefer passing that CA bundle path to `verify`.
    response = requests.post(url, headers=headers, data=data, verify=False)
    # Surface the HTTP status on auth failure instead of an opaque KeyError below.
    response.raise_for_status()
    access_token = response.json()['access_token']
    print('Got access token')

    if enable_summarization:
        with st.spinner('Обрабатываем транскрибацию...'):
            # The processing pass uses a fixed 32768-token cap rather than the
            # sidebar's summary limit (presumably because the rewritten output
            # can be as long as the transcript itself).
            transcript = get_completion_from_gigachat((giga_processing_prompt or '') + transcript, 32768, access_token)
        st.write("Результат обработки:")
        st.text(transcript)

    with st.spinner('Резюмируем...'):
        summary_answer = get_completion_from_gigachat((giga_base_prompt or '') + transcript, giga_max_tokens, access_token)
    st.write("Результат резюмирования:")
    st.text(summary_answer)