{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "intro-overview",
   "metadata": {},
   "source": [
    "# `stf_alternative` smoke test\n",
    "\n",
    "Loads a model and a template video, pulls the first item from `template.gen_infer` on silent audio, and checks the output shapes of `model.audio_encoder` and `model.model` on random inputs.\n",
    "\n",
    "Requires the CUDA device and the local files named in the configuration cell."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f3e86ccf-bfe9-48f0-a802-ff96a0bd3323",
   "metadata": {},
   "outputs": [],
   "source": [
    "# All imports live here so a fresh kernel can run the notebook top to bottom.\n",
    "import torch\n",
    "from PIL import Image  # not referenced below; kept so the original import set is unchanged\n",
    "from pydub import AudioSegment\n",
    "\n",
    "import stf_alternative"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "config-heading",
   "metadata": {},
   "source": [
    "## Configuration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "config-paths",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Everything a reader might need to change is collected in this cell.\n",
    "config_path = \"front_config.json\"\n",
    "checkpoint_path = \"089.pth\"\n",
    "work_root_path = \"works\"\n",
    "device = \"cuda:0\"\n",
    "wavlm_path = \"microsoft/wavlm-large\"\n",
    "template_video_path = \"templates/front_one_piece_dress_nodded_cut.webm\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "load-heading",
   "metadata": {},
   "source": [
    "## Load model and template"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0f33c097-c51f-4855-a632-52f1bba575f5",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = stf_alternative.create_model(\n",
    "    config_path=config_path,\n",
    "    checkpoint_path=checkpoint_path,\n",
    "    work_root_path=work_root_path,\n",
    "    device=device,\n",
    "    wavlm_path=wavlm_path,\n",
    ")\n",
    "template = stf_alternative.Template(\n",
    "    model=model,\n",
    "    config_path=config_path,\n",
    "    template_video_path=template_video_path,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "infer-heading",
   "metadata": {},
   "source": [
    "## First inference item on silent audio"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0b4aae64-0e71-445d-9999-93f3790c61dd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# pydub durations are in milliseconds, so this is 2 seconds of silence.\n",
    "silent = AudioSegment.silent(duration=2000)\n",
    "gen_infer = template.gen_infer(audio_segment=silent, video_start_offset_frame=0)\n",
    "\n",
    "# Only the first (pred, chunk) pair is needed. next() raises StopIteration if\n",
    "# nothing is yielded, instead of silently leaving the names undefined.\n",
    "# Distinct names keep this result from being overwritten by `pred` below.\n",
    "first_pred, first_chunk = next(iter(gen_infer))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "shapes-heading",
   "metadata": {},
   "source": [
    "## Shape checks on random inputs\n",
    "\n",
    "Shapes recorded from the previous run of this notebook:\n",
    "\n",
    "- `model.audio_encoder(...).last_hidden_state[0]` on a `(1, 10000)` input: `torch.Size([31, 1024])`\n",
    "- `model.model(...)` on a `(1, 9, 352, 352)` image and `(1, 78, 1024)` audio input: `torch.Size([1, 3, 352, 352])`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fc8dec84-ebf4-4b8c-a4a3-4c2cd6fdc891",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Shape check only, so skip autograd bookkeeping.\n",
    "# Tensors follow the configured `device` rather than a hard-coded .cuda().\n",
    "with torch.no_grad():\n",
    "    audio_features = model.audio_encoder(\n",
    "        input_values=torch.randn(1, 10000).to(device)\n",
    "    ).last_hidden_state[0]\n",
    "audio_features.shape  # previously observed: torch.Size([31, 1024])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a47a4d7d-8e46-409a-a2f6-a6b875d2c877",
   "metadata": {},
   "outputs": [],
   "source": [
    "with torch.no_grad():\n",
    "    pred = model.model(\n",
    "        img=torch.randn(1, 9, 352, 352).to(device),\n",
    "        audio=torch.randn(1, 78, 1024).to(device),\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "087daeec-0bc6-4ee5-97aa-ca6b4c63f3e8",
   "metadata": {},
   "outputs": [],
   "source": [
    "pred.shape  # previously observed: torch.Size([1, 3, 352, 352])"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}