pragnakalp commited on
Commit
7f3a4c0
1 Parent(s): 125c3bf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +50 -54
app.py CHANGED
@@ -1,61 +1,57 @@
1
- # import gradio as gr
2
- # import os, subprocess, torchaudio
3
- # import torch
4
- # from PIL import Image
5
-
6
- # block = gr.Blocks()
7
-
8
- # def pad_image(image):
9
- # w, h = image.size
10
- # if w == h:
11
- # return image
12
- # elif w > h:
13
- # new_image = Image.new(image.mode, (w, w), (0, 0, 0))
14
- # new_image.paste(image, (0, (w - h) // 2))
15
- # return new_image
16
- # else:
17
- # new_image = Image.new(image.mode, (h, h), (0, 0, 0))
18
- # new_image.paste(image, ((h - w) // 2, 0))
19
- # return new_image
20
-
21
- # def calculate(image_in, audio_in):
22
- # waveform, sample_rate = torchaudio.load(audio_in)
23
- # waveform = torch.mean(waveform, dim=0, keepdim=True)
24
- # torchaudio.save("/content/audio.wav", waveform, sample_rate, encoding="PCM_S", bits_per_sample=16)
25
- # image = Image.open(image_in)
26
- # image = pad_image(image)
27
- # image.save("image.png")
28
-
29
- # pocketsphinx_run = subprocess.run(['pocketsphinx', '-phone_align', 'yes', 'single', '/content/audio.wav'], check=True, capture_output=True)
30
- # jq_run = subprocess.run(['jq', '[.w[]|{word: (.t | ascii_upcase | sub("<S>"; "sil") | sub("<SIL>"; "sil") | sub("\\\(2\\\)"; "") | sub("\\\(3\\\)"; "") | sub("\\\(4\\\)"; "") | sub("\\\[SPEECH\\\]"; "SIL") | sub("\\\[NOISE\\\]"; "SIL")), phones: [.w[]|{ph: .t | sub("\\\+SPN\\\+"; "SIL") | sub("\\\+NSN\\\+"; "SIL"), bg: (.b*100)|floor, ed: (.b*100+.d*100)|floor}]}]'], input=pocketsphinx_run.stdout, capture_output=True)
31
- # with open("test.json", "w") as f:
32
- # f.write(jq_run.stdout.decode('utf-8').strip())
33
- # # device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
34
- # os.system(f"cd /content/one-shot-talking-face && python3 -B test_script.py --img_path /content/image.png --audio_path /content/audio.wav --phoneme_path /content/test.json --save_dir /content/train")
35
- # return "/content/train/image_audio.mp4"
36
 
37
- # def run():
38
- # with block:
39
 
40
- # with gr.Group():
41
- # with gr.Box():
42
- # with gr.Row().style(equal_height=True):
43
- # image_in = gr.Image(show_label=False, type="filepath")
44
- # audio_in = gr.Audio(show_label=False, type='filepath')
45
- # video_out = gr.Video(show_label=False)
46
- # with gr.Row().style(equal_height=True):
47
- # btn = gr.Button("Generate")
48
 
49
 
50
- # btn.click(calculate, inputs=[image_in, audio_in], outputs=[video_out])
51
- # block.queue()
52
- # block.launch(server_name="0.0.0.0", server_port=7860)
53
 
54
- # if __name__ == "__main__":
55
- # run()
56
 
57
- import torch
58
- print(torch.cuda.is_available())
59
 
60
- print(torch.cuda.device_count())
61
- print(torch.device('cpu'))
 
1
+ import gradio as gr
2
+ import os, subprocess, torchaudio
3
+ import torch
4
+ from PIL import Image
5
+
6
# Module-level Blocks container shared with run(), which fills it with the UI
# and launches it.
block = gr.Blocks()
7
+
8
def pad_image(image):
    """Return *image* padded with black borders so it becomes square.

    The new canvas side equals the longer of the two dimensions and the
    original is centered along the shorter axis. A square input is returned
    unchanged (same object, no copy).

    NOTE(review): the (0, 0, 0) fill assumes a 3-band mode such as RGB —
    confirm callers never pass e.g. mode "L" images.
    """
    width, height = image.size
    if width == height:
        return image
    side = max(width, height)
    canvas = Image.new(image.mode, (side, side), (0, 0, 0))
    # Offsets reduce to (0, (w-h)//2) when landscape and ((h-w)//2, 0)
    # when portrait, exactly centering the shorter dimension.
    canvas.paste(image, ((side - width) // 2, (side - height) // 2))
    return canvas
20
+
21
def calculate(image_in, audio_in):
    """Render a talking-head video from a portrait image and an audio clip.

    Parameters
    ----------
    image_in : str
        Filesystem path of the input image.
    audio_in : str
        Filesystem path of the input audio.

    Returns
    -------
    str
        Path of the video produced by the one-shot-talking-face script.

    Raises
    ------
    subprocess.CalledProcessError
        If pocketsphinx or jq exits with a non-zero status.
    """
    # Downmix to mono and save as 16-bit PCM, the format the aligner expects.
    waveform, sample_rate = torchaudio.load(audio_in)
    waveform = torch.mean(waveform, dim=0, keepdim=True)
    torchaudio.save("/content/audio.wav", waveform, sample_rate, encoding="PCM_S", bits_per_sample=16)

    # The face model assumes a square input image.
    image = pad_image(Image.open(image_in))
    # Fix: save to /content/image.png — the render command below reads that
    # path; the previous CWD-relative "image.png" only worked when the
    # process happened to run from /content.
    image.save("/content/image.png")

    # Phoneme alignment: pocketsphinx emits word/phone timings, jq reshapes
    # them into the {word, phones:[{ph, bg, ed}]} structure test_script.py
    # expects (times converted to centiseconds).
    pocketsphinx_run = subprocess.run(
        ['pocketsphinx', '-phone_align', 'yes', 'single', '/content/audio.wav'],
        check=True, capture_output=True)
    # Fix: check=True so a jq failure raises instead of silently producing an
    # empty phoneme file.
    jq_run = subprocess.run(
        ['jq', '[.w[]|{word: (.t | ascii_upcase | sub("<S>"; "sil") | sub("<SIL>"; "sil") | sub("\\\(2\\\)"; "") | sub("\\\(3\\\)"; "") | sub("\\\(4\\\)"; "") | sub("\\\[SPEECH\\\]"; "SIL") | sub("\\\[NOISE\\\]"; "SIL")), phones: [.w[]|{ph: .t | sub("\\\+SPN\\\+"; "SIL") | sub("\\\+NSN\\\+"; "SIL"), bg: (.b*100)|floor, ed: (.b*100+.d*100)|floor}]}]'],
        input=pocketsphinx_run.stdout, check=True, capture_output=True)
    # Fix: write to the /content path the render command actually reads,
    # not the CWD-relative "test.json".
    with open("/content/test.json", "w") as f:
        f.write(jq_run.stdout.decode('utf-8').strip())

    os.system("cd /content/one-shot-talking-face && python3 -B test_script.py --img_path /content/image.png --audio_path /content/audio.wav --phoneme_path /content/test.json --save_dir /content/train")
    return "/content/train/image_audio.mp4"
36
 
37
def run():
    """Build the UI inside the module-level `block` and serve it.

    Lays out image/audio inputs beside the video output, wires the
    Generate button to calculate(), then starts the queued Gradio server
    on 0.0.0.0:7860 (blocks until shutdown).
    """
    with block:
        with gr.Group(), gr.Box():
            # Inputs and the rendered video side by side.
            with gr.Row().style(equal_height=True):
                image_in = gr.Image(show_label=False, type="filepath")
                audio_in = gr.Audio(show_label=False, type='filepath')
                video_out = gr.Video(show_label=False)
            # Action row underneath.
            with gr.Row().style(equal_height=True):
                btn = gr.Button("Generate")

        btn.click(calculate, inputs=[image_in, audio_in], outputs=[video_out])

    block.queue()
    block.launch(server_name="0.0.0.0", server_port=7860)
53
 
54
# Start the web app only when executed as a script, not on import.
if __name__ == "__main__":
    run()
56
 
 
 
57