# Fvds / app.py
# pragnakalp's picture
# Update app.py
# 6743cbf
# raw
# history blame
# 3.17 kB
import gradio as gr
import os, subprocess, torchaudio
import torch
from PIL import Image
import gradio as gr
import os, subprocess, torchaudio
import torch
from PIL import Image
import soundfile
from gtts import gTTS
import tempfile
from pydub import AudioSegment
from pydub.generators import Sine
# Module-level Gradio Blocks container; run() lays the UI out inside it and launches it.
block = gr.Blocks()
def pad_image(image):
    """Pad a PIL image with black so that it becomes square.

    The picture is centred along its shorter axis. An image that is already
    square is returned untouched.

    Args:
        image: A ``PIL.Image.Image`` (anything exposing ``size`` and ``mode``
            and accepted by ``Image.paste``).

    Returns:
        The same object when it is already square, otherwise a new image of
        the same mode whose side equals the longer edge of the input.
    """
    width, height = image.size
    if width == height:
        return image

    side = max(width, height)
    # A colour *name* is converted by Pillow to the proper value for the
    # image's mode; the literal (0, 0, 0) tuple is rejected for single-band
    # modes such as "L", which made grayscale uploads crash.
    canvas = Image.new(image.mode, (side, side), "black")
    if width > height:
        offset = (0, (side - height) // 2)
    else:
        offset = ((side - width) // 2, 0)
    canvas.paste(image, offset)
    return canvas
def calculate(image_in, audio_in):
    """Generate a talking-face video from a portrait image and a speech clip.

    Steps:
      1. Down-mix the audio to mono and store it as 16-bit PCM WAV.
      2. Pad the image to a square and store it as PNG.
      3. Run ``pocketsphinx`` for a phone alignment and reshape its JSON
         output with ``jq`` into the phoneme file.
      4. Run ``test_script.py`` from the one-shot-talking-face checkout on
         the three files.

    Args:
        image_in: Path of the input image.
        audio_in: Path of the input audio file.

    Returns:
        The hard-coded path where the generation script is expected to have
        written the video.

    Raises:
        subprocess.CalledProcessError: If pocketsphinx, jq or the generation
            script exits with a non-zero status.
    """
    # Every intermediate file lives under one absolute directory so that the
    # paths written here are exactly the paths handed to the external script,
    # regardless of the current working directory.
    work_dir = "/content"
    audio_path = f"{work_dir}/audio.wav"
    image_path = f"{work_dir}/image.png"
    phoneme_path = f"{work_dir}/test.json"
    save_dir = f"{work_dir}/train"

    waveform, sample_rate = torchaudio.load(audio_in)
    # Average the channels to obtain a single (mono) channel.
    waveform = torch.mean(waveform, dim=0, keepdim=True)
    torchaudio.save(audio_path, waveform, sample_rate, encoding="PCM_S", bits_per_sample=16)

    image = pad_image(Image.open(image_in))
    image.save(image_path)

    # Raw string: the doubled backslashes must reach jq untouched, where the
    # jq string literal "\\(" becomes the regex \( (a literal parenthesis).
    jq_filter = (
        r'[.w[]|{word: (.t | ascii_upcase'
        r' | sub("<S>"; "sil") | sub("<SIL>"; "sil")'
        r' | sub("\\(2\\)"; "") | sub("\\(3\\)"; "") | sub("\\(4\\)"; "")'
        r' | sub("\\[SPEECH\\]"; "SIL") | sub("\\[NOISE\\]"; "SIL")),'
        r' phones: [.w[]|{ph: .t | sub("\\+SPN\\+"; "SIL")'
        r' | sub("\\+NSN\\+"; "SIL"),'
        r' bg: (.b*100)|floor, ed: (.b*100+.d*100)|floor}]}]'
    )

    alignment = subprocess.run(
        ["pocketsphinx", "-phone_align", "yes", "single", audio_path],
        check=True,
        capture_output=True,
    )
    # check=True so a jq failure is reported instead of silently producing an
    # empty phoneme file.
    phonemes = subprocess.run(
        ["jq", jq_filter],
        input=alignment.stdout,
        check=True,
        capture_output=True,
    )
    with open(phoneme_path, "w", encoding="utf-8") as phoneme_file:
        phoneme_file.write(phonemes.stdout.decode("utf-8").strip())

    # Argument list + cwd instead of a shell string with `cd`; a failing
    # script now raises rather than returning a path to a missing video.
    subprocess.run(
        [
            "python3", "-B", "test_script.py",
            "--img_path", image_path,
            "--audio_path", audio_path,
            "--phoneme_path", phoneme_path,
            "--save_dir", save_dir,
        ],
        cwd=f"{work_dir}/one-shot-talking-face",
        check=True,
    )
    return f"{save_dir}/image_audio.mp4"
def one_shot(image_in, audio_in, gender):
    """Return the path of the audio clip to use for the selected gender.

    For ``"Female"`` a fixed sentence is synthesised with gTTS, converted
    from mp3 to WAV and written to ``/content/audio.wav``; that path is
    returned. For any other value the uploaded ``audio_in`` is returned
    unchanged.

    NOTE(review): only "Female" triggers synthesis; confirm that passing the
    upload through for "Male" is the intended behaviour.

    Args:
        image_in: Path of the uploaded image (currently unused).
        audio_in: Path of the uploaded audio, or ``None``.
        gender: Value of the gender radio button.

    Returns:
        Path of the audio file to play.
    """
    if gender != "Female":
        return audio_in

    wav_path = "/content/audio.wav"
    tts = gTTS("Hello i am john")

    # delete=False because the mp3 is re-opened by name below; it is removed
    # explicitly in the finally clause so temp files do not accumulate.
    mp3_file = tempfile.NamedTemporaryFile(suffix=".mp3", delete=False)
    try:
        with mp3_file:
            tts.write_to_fp(mp3_file)
        # The file is closed (and therefore flushed) before it is decoded.
        sound = AudioSegment.from_file(mp3_file.name, format="mp3")
        sound.export(wav_path, format="wav")
    finally:
        os.remove(mp3_file.name)
    return wav_path
def run():
    """Lay out the Gradio UI inside the module-level ``block`` and serve it.

    The first row holds the image upload, the audio upload, the gender
    selector and the audio output; the second row holds the "Generate"
    button, which calls ``one_shot`` with the three inputs. The app is then
    queued and launched on 0.0.0.0:7860.
    """
    with block:
        with gr.Group(), gr.Box():
            with gr.Row().style(equal_height=True):
                image_in = gr.Image(type="filepath", show_label=False)
                audio_in = gr.Audio(type="filepath", show_label=False)
                gender = gr.Radio(["Female", "Male"], label="Gender", value="Female")
                audio_out = gr.Audio(show_label=False)

            with gr.Row().style(equal_height=True):
                generate_button = gr.Button("Generate")

        # Wire the button: one_shot's return value feeds the output player.
        generate_button.click(
            one_shot,
            inputs=[image_in, audio_in, gender],
            outputs=[audio_out],
        )

    block.queue()
    block.launch(server_name="0.0.0.0", server_port=7860)
# Build and launch the app only when this file is executed as a script.
if __name__ == "__main__":
    run()