yerang committed on
Commit
32e6b3a
1 Parent(s): bd30786

Upload stf/Untitled.ipynb with huggingface_hub

Browse files
Files changed (1) hide show
  1. stf/Untitled.ipynb +140 -0
stf/Untitled.ipynb ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 5,
6
+ "id": "f3e86ccf-bfe9-48f0-a802-ff96a0bd3323",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "import stf_alternative"
11
+ ]
12
+ },
13
+ {
14
+ "cell_type": "code",
15
+ "execution_count": 12,
16
+ "id": "0f33c097-c51f-4855-a632-52f1bba575f5",
17
+ "metadata": {},
18
+ "outputs": [],
19
+ "source": [
20
+ "config_path = \"front_config.json\"\n",
21
+ "checkpoint_path = \"089.pth\"\n",
22
+ "work_root_path = \"works\"\n",
23
+ "device = \"cuda:0\"\n",
24
+ "\n",
25
+ "model = stf_alternative.create_model(\n",
26
+ " config_path=config_path,\n",
27
+ " checkpoint_path=checkpoint_path,\n",
28
+ " work_root_path=work_root_path,\n",
29
+ " device=device,\n",
30
+ " wavlm_path=\"microsoft/wavlm-large\",\n",
31
+ ")\n",
32
+ "template = stf_alternative.Template(\n",
33
+ " model=model,\n",
34
+ " config_path=config_path,\n",
35
+ " template_video_path=\"templates/front_one_piece_dress_nodded_cut.webm\",\n",
36
+ ")"
37
+ ]
38
+ },
39
+ {
40
+ "cell_type": "code",
41
+ "execution_count": 17,
42
+ "id": "0b4aae64-0e71-445d-9999-93f3790c61dd",
43
+ "metadata": {},
44
+ "outputs": [],
45
+ "source": [
46
+ "from pydub import AudioSegment\n",
47
+ "silent = AudioSegment.silent(2000)\n",
48
+ "gen_infer = template.gen_infer(audio_segment=silent,video_start_offset_frame=0)\n",
49
+ "from PIL import Image\n",
50
+ "for pred, chunk in gen_infer:\n",
51
+ " break"
52
+ ]
53
+ },
54
+ {
55
+ "cell_type": "code",
56
+ "execution_count": 36,
57
+ "id": "fc8dec84-ebf4-4b8c-a4a3-4c2cd6fdc891",
58
+ "metadata": {},
59
+ "outputs": [
60
+ {
61
+ "data": {
62
+ "text/plain": [
63
+ "torch.Size([31, 1024])"
64
+ ]
65
+ },
66
+ "execution_count": 36,
67
+ "metadata": {},
68
+ "output_type": "execute_result"
69
+ }
70
+ ],
71
+ "source": [
72
+ "import torch\n",
73
+ "model.audio_encoder(input_values=torch.randn(1, 10000).cuda()).last_hidden_state[0].shape"
74
+ ]
75
+ },
76
+ {
77
+ "cell_type": "code",
78
+ "execution_count": 38,
79
+ "id": "a47a4d7d-8e46-409a-a2f6-a6b875d2c877",
80
+ "metadata": {},
81
+ "outputs": [],
82
+ "source": [
83
+ "pred = model.model(\n",
84
+ " img=torch.randn(1, 9, 352, 352).cuda(),\n",
85
+ " audio=torch.randn(1, 78, 1024).cuda(),\n",
86
+ ")"
87
+ ]
88
+ },
89
+ {
90
+ "cell_type": "code",
91
+ "execution_count": 40,
92
+ "id": "087daeec-0bc6-4ee5-97aa-ca6b4c63f3e8",
93
+ "metadata": {},
94
+ "outputs": [
95
+ {
96
+ "data": {
97
+ "text/plain": [
98
+ "torch.Size([1, 3, 352, 352])"
99
+ ]
100
+ },
101
+ "execution_count": 40,
102
+ "metadata": {},
103
+ "output_type": "execute_result"
104
+ }
105
+ ],
106
+ "source": [
107
+ "pred.shape"
108
+ ]
109
+ },
110
+ {
111
+ "cell_type": "code",
112
+ "execution_count": null,
113
+ "id": "da6489ab-8766-4ebc-ae30-809a4a366aa2",
114
+ "metadata": {},
115
+ "outputs": [],
116
+ "source": []
117
+ }
118
+ ],
119
+ "metadata": {
120
+ "kernelspec": {
121
+ "display_name": "Python 3 (ipykernel)",
122
+ "language": "python",
123
+ "name": "python3"
124
+ },
125
+ "language_info": {
126
+ "codemirror_mode": {
127
+ "name": "ipython",
128
+ "version": 3
129
+ },
130
+ "file_extension": ".py",
131
+ "mimetype": "text/x-python",
132
+ "name": "python",
133
+ "nbconvert_exporter": "python",
134
+ "pygments_lexer": "ipython3",
135
+ "version": "3.10.12"
136
+ }
137
+ },
138
+ "nbformat": 4,
139
+ "nbformat_minor": 5
140
+ }