# Hugging Face file-view header captured together with this source:
# Sabbah13's picture
# Update app.py
# b688722 verified
# raw
# history blame
# No virus
# 6.95 kB
import base64
import json
import os
import tempfile

import requests
import streamlit as st
import whisperx
import torch
def _format_speaker_line(speaker, start, end, words):
    """Render one speaker turn as ``"<speaker> (<start> : <end>) : <text>"``.

    The time range is omitted (``"<speaker> : <text>"``) when either bound
    is ``None``.
    """
    text = " ".join(words)
    if start is not None and end is not None:
        return f'{speaker} ({start} : {end}) : {text}'
    return f'{speaker} : {text}'


def _fill_missing_word_fields(words, segment_speaker, segment_start, segment_end):
    """Return shallow copies of *words* with 'speaker', 'start', 'end' filled in.

    Fill rules (applied word by word, left to right):
    - speaker: previous word's (already filled) speaker, else the next word's
      speaker, else the segment's speaker;
    - start: previous word's (already filled) end, else the segment's start;
    - end: next word's start, else the segment's end for the last word,
      else this word's own start.

    The input dicts are not modified.
    """
    filled = [dict(word_info) for word_info in words]
    last = len(filled) - 1
    for i, info in enumerate(filled):
        # filled[i - 1] has already been completed; filled[i + 1] is still an
        # untouched copy of the caller's data.
        if 'speaker' not in info:
            if i > 0 and 'speaker' in filled[i - 1]:
                info['speaker'] = filled[i - 1]['speaker']
            elif i < last and 'speaker' in filled[i + 1]:
                info['speaker'] = filled[i + 1]['speaker']
            else:
                info['speaker'] = segment_speaker
        if 'start' not in info:
            if i > 0 and 'end' in filled[i - 1]:
                info['start'] = filled[i - 1]['end']
            else:
                info['start'] = segment_start
        if 'end' not in info:
            if i < last and 'start' in filled[i + 1]:
                info['end'] = filled[i + 1]['start']
            elif i == last:
                info['end'] = segment_end
            else:
                info['end'] = info['start']
    return filled


def convert_segments_object_to_text(data):
    """Convert a diarized transcription result into speaker-labelled text.

    *data* must contain a ``'segments'`` list. Each segment may hold
    ``'words'`` (a list of dicts with optional ``'word'``, ``'start'``,
    ``'end'``, ``'speaker'`` keys) plus optional segment-level ``'speaker'``,
    ``'start'`` and ``'end'`` used as fallbacks for missing word fields. In
    this app it receives the output of ``whisperx.assign_word_speakers``.

    Consecutive words of the same speaker within one segment are merged into
    one line; turns never span segment boundaries. A missing speaker (``None``)
    is treated as its own speaker rather than being merged into a neighbour.

    Returns:
        One line per speaker turn, joined with ``'\\n'`` (empty string when
        there are no words). The input is not modified.
    """
    lines = []
    for segment in data['segments']:
        words = _fill_missing_word_fields(
            segment.get('words') or [],
            segment.get('speaker'),
            segment.get('start'),
            segment.get('end'),
        )
        # Each turn is [speaker, start, end, list_of_words]. An explicit list
        # (not a None sentinel) marks "no turn yet", so speaker=None is safe.
        turns = []
        for info in words:
            speaker = info.get('speaker')
            word = info.get('word', '')
            if turns and turns[-1][0] == speaker:
                turns[-1][2] = info.get('end')
                turns[-1][3].append(word)
            else:
                turns.append([speaker, info.get('start'), info.get('end'), [word]])
        lines.extend(_format_speaker_line(*turn) for turn in turns)
    return '\n'.join(lines)
# NOTE(review): the original code always passed verify=False (TLS certificate
# checks disabled), presumably because the GigaChat endpoints use a certificate
# chain missing from the default trust store. That remains the default; set
# GIGA_CA_BUNDLE to a CA bundle path to re-enable verification.
GIGA_TLS_VERIFY = os.getenv('GIGA_CA_BUNDLE') or False
# (connect, read) timeouts in seconds so a stalled endpoint cannot hang the app.
GIGA_HTTP_TIMEOUT = (10, 300)


def _transcribe_and_diarize(audio_path, device, batch_size, compute_type, hf_token):
    """Run WhisperX ASR, word alignment and speaker diarization on *audio_path*.

    Returns the result of ``whisperx.assign_word_speakers``.
    """
    model = whisperx.load_model(os.getenv('WHISPER_MODEL_SIZE'), device, compute_type=compute_type)
    audio = whisperx.load_audio(audio_path)
    result = model.transcribe(audio, batch_size=batch_size, language="ru")
    print('Transcribed, now aligning')
    model_a, metadata = whisperx.load_align_model(language_code=result["language"], device=device)
    result = whisperx.align(result["segments"], model_a, metadata, audio, device, return_char_alignments=False)
    print('Aligned, now diarizing')
    diarize_model = whisperx.DiarizationPipeline(use_auth_token=hf_token, device=device)
    diarize_segments = diarize_model(audio)
    return whisperx.assign_word_speakers(diarize_segments, result)


def _get_giga_access_token():
    """Request a GigaChat access token using Basic auth from Streamlit secrets.

    Raises:
        requests.HTTPError: if the auth endpoint returns an error status.
    """
    username = st.secrets["GIGA_USERNAME"]
    password = st.secrets["GIGA_SECRET"]
    # Basic-auth credentials are "user:password" encoded as Base64.
    auth_base64 = base64.b64encode(f'{username}:{password}'.encode('utf-8')).decode('utf-8')
    headers = {
        'Authorization': f'Basic {auth_base64}',
        'RqUID': os.getenv('GIGA_rquid'),
        'Content-Type': 'application/x-www-form-urlencoded',
        'Accept': 'application/json',
    }
    data = {'scope': os.getenv('GIGA_SCOPE')}
    response = requests.post(
        os.getenv('GIGA_AUTH_URL'),
        headers=headers,
        data=data,
        verify=GIGA_TLS_VERIFY,
        timeout=GIGA_HTTP_TIMEOUT,
    )
    response.raise_for_status()
    return response.json()['access_token']


def _summarize(transcript, access_token):
    """Send GIGA_BASE_PROMPT + *transcript* to the GigaChat completion API.

    Returns the content of the first choice's message.

    Raises:
        requests.HTTPError: if the completion endpoint returns an error status.
    """
    payload = json.dumps({
        "model": os.getenv('GIGA_MODEL'),
        "messages": [
            {
                "role": "user",
                "content": os.getenv('GIGA_BASE_PROMPT') + transcript,
            }
        ],
        "stream": False,
        "max_tokens": int(os.getenv('GIGA_MAX_TOKENS')),
    })
    headers = {
        'Content-Type': 'application/json',
        'Accept': 'application/json',
        'Authorization': 'Bearer ' + access_token,
    }
    response = requests.post(
        os.getenv('GIGA_COMPLETION_URL'),
        headers=headers,
        data=payload,
        verify=GIGA_TLS_VERIFY,
        timeout=GIGA_HTTP_TIMEOUT,
    )
    response.raise_for_status()
    return response.json()['choices'][0]['message']['content']


st.title('Audio Transcription App')
st.sidebar.title("Settings")

# Sidebar inputs
device = st.sidebar.selectbox("Device", ["cpu", "cuda"], index=1)
batch_size = st.sidebar.number_input("Batch Size", min_value=1, value=16)
compute_type = st.sidebar.selectbox("Compute Type", ["float16", "int8"], index=0)
ACCESS_TOKEN = st.secrets["HF_TOKEN"]

uploaded_file = st.file_uploader("Загрузите аудиофайл", type=["mp4", "wav", "m4a"])
if uploaded_file is not None:
    st.audio(uploaded_file)

    # Write the upload to a unique temp file, keeping the original extension.
    # A fixed name in the working directory would be overwritten when two
    # sessions upload at once, and it was never cleaned up.
    suffix = os.path.splitext(uploaded_file.name)[1]
    with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
        tmp.write(uploaded_file.getbuffer())
        temp_file_path = tmp.name

    try:
        with st.spinner('Транскрибируем...'):
            result_diar = _transcribe_and_diarize(
                temp_file_path, device, batch_size, compute_type, ACCESS_TOKEN
            )
    finally:
        os.remove(temp_file_path)

    st.write("Результат транскрибации:")
    transcript = convert_segments_object_to_text(result_diar)
    st.text(transcript)

    with st.spinner('Резюмируем...'):
        access_token = _get_giga_access_token()
        print('Got access token')
        answer_from_llm = _summarize(transcript, access_token)

    st.write("Результат резюмирования:")
    st.text(answer_from_llm)