Sabbah13's picture
Update app.py
7b9fbc2 verified
raw
history blame
No virus
6.41 kB
import requests
import base64
import os
import json
import streamlit as st
import whisperx
import torch
from utils import convert_segments_object_to_text
def get_completion_from_gigachat(prompt, max_tokens, access_token):
url_completion = os.getenv('GIGA_COMPLETION_URL')
data_copm = json.dumps({
"model": os.getenv('GIGA_MODEL'),
"messages": [
{
"role": "user",
"content": prompt
}
],
"stream": False,
"max_tokens": max_tokens,
})
headers_comp = {
'Content-Type': 'application/json',
'Accept': 'application/json',
'Authorization': 'Bearer ' + access_token
}
response = requests.post(url_completion, headers=headers_comp, data=data_copm, verify=False)
response_data = response.json()
answer_from_llm = response_data['choices'][0]['message']['content']
return answer_from_llm
st.title('Audio Transcription App')
st.sidebar.title("Settings")
# Sidebar inputs
device = st.sidebar.selectbox("Device", ["cpu", "cuda"], index=1)
batch_size = st.sidebar.number_input("Batch Size", min_value=1, value=16)
compute_type = st.sidebar.selectbox("Compute Type", ["float16", "int8"], index=0)
initial_giga_base_prompt = "Напиши резюме транскрибации звонка, текст которого приложен в ниже. Выдели самостоятельно цель встречи, потом описать ключевые моменты всей встречи. Потом выделить отдельные темы звонка и выделить ключевые моменты в них. Напиши итоги того, о чем договорились говорящие, если такое возможно выделить из текста.\nТранскрибация: "
initial_giga_processing_prompt = "Обработай транкрибацию звонка. Убедись, что каждое слово назначено правильному спикеру. Если заметишь, что слово или фраза ошибочно приписаны другому спикеру, исправь это. Постарайся понять имена говорящих из контекста разговора и замени «SPEAKER_00», «SPEAKER_01» и т.д. на их реальные имена. Если чье-то имя понять невозможно, то не меняй его. Приложи в ответе обработанную транскрибацию\nТранскрибация: "
giga_base_prompt = st.sidebar.text_area("Промпт ГигаЧата для резюмирования", value=initial_giga_base_prompt)
giga_max_tokens = st.sidebar.number_input("Максимальное количество токенов при резюмировании", min_value=1, value=1024)
enable_summarization = st.sidebar.checkbox("Добавить обработку транскрибации", value=False)
giga_processing_prompt = st.sidebar.text_area("Промпт ГигаЧата для обработки транскрибации", value=initial_giga_processing_prompt)
ACCESS_TOKEN = st.secrets["HF_TOKEN"]
uploaded_file = st.file_uploader("Загрузите аудиофайл", type=["mp4", "wav", "m4a"])
if uploaded_file is not None:
st.audio(uploaded_file)
file_extension = uploaded_file.name.split(".")[-1] # Получаем расширение файла
temp_file_path = f"temp_file.{file_extension}" # Создаем временное имя файла с правильным расширением
with open(temp_file_path, "wb") as f:
f.write(uploaded_file.getbuffer())
with st.spinner('Транскрибируем...'):
# Load model
model = whisperx.load_model(os.getenv('WHISPER_MODEL_SIZE'), device, compute_type=compute_type)
# Load and transcribe audio
audio = whisperx.load_audio(temp_file_path)
result = model.transcribe(audio, batch_size=batch_size, language="ru")
print('Transcribed, now aligning')
model_a, metadata = whisperx.load_align_model(language_code=result["language"], device=device)
result = whisperx.align(result["segments"], model_a, metadata, audio, device, return_char_alignments=False)
print('Aligned, now diarizing')
diarize_model = whisperx.DiarizationPipeline(use_auth_token=st.secrets["HF_TOKEN"], device=device)
diarize_segments = diarize_model(audio)
result_diar = whisperx.assign_word_speakers(diarize_segments, result)
st.write("Результат транскрибации:")
transcript = convert_segments_object_to_text(result_diar)
st.text(transcript)
with st.spinner('Обрабатываем транскрибацию...'):
username = st.secrets["GIGA_USERNAME"]
password = st.secrets["GIGA_SECRET"]
# Получаем строку с базовой авторизацией в формате Base64
auth_str = f'{username}:{password}'
auth_bytes = auth_str.encode('utf-8')
auth_base64 = base64.b64encode(auth_bytes).decode('utf-8')
url = os.getenv('GIGA_AUTH_URL')
headers = {
'Authorization': f'Basic {auth_base64}', # вставляем базовую авторизацию
'RqUID': os.getenv('GIGA_rquid'),
'Content-Type': 'application/x-www-form-urlencoded',
'Accept': 'application/json'
}
data = {
'scope': os.getenv('GIGA_SCOPE')
}
response = requests.post(url, headers=headers, data=data, verify=False)
access_token = response.json()['access_token']
print('Got access token')
transcribe_answer = get_completion_from_gigachat(giga_processing_prompt + transcript, 32768, access_token)
st.write("Результат обработки:")
st.text(transcribe_answer)
with st.spinner('Резюмируем...'):
summary_answer = get_completion_from_gigachat(giga_base_prompt + transcribe_answer, giga_max_tokens, access_token)
st.write("Результат резюмирования:")
st.text(summary_answer)