# Prediction interface for Cog ⚙️
# https://github.com/replicate/cog/blob/main/docs/python.md

import os
import subprocess
from functools import wraps
from time import time

import torch
from cog import BasePredictor, Input, Path

import inference


def make_mem_efficient(cls: BasePredictor):
    """Class decorator: keep models on the CPU between requests and move them
    to the GPU only for the duration of each prediction, so VRAM is freed
    while the predictor is idle."""
    if not torch.cuda.is_available():
        return cls

    old_setup = cls.setup
    old_predict = cls.predict

    @wraps(old_setup)
    def new_setup(self, *args, **kwargs):
        ret = old_setup(self, *args, **kwargs)
        _move_to(self, "cpu")
        return ret

    @wraps(old_predict)
    def new_predict(self, *args, **kwargs):
        _move_to(self, "cuda")
        try:
            ret = old_predict(self, *args, **kwargs)
        finally:
            # Always release the GPU, even if the prediction raised.
            _move_to(self, "cpu")
        return ret

    cls.setup = new_setup
    cls.predict = new_predict

    return cls


def _move_to(self, device):
    # If the predictor caches its models on a separate object (here: the
    # `inference` module), operate on that object instead.
    try:
        self = self.cached_models
    except AttributeError:
        pass
    # `self` is usually a module here (see Predictor.cached_models), so
    # __name__ exists; fall back to the class name otherwise.
    name = getattr(self, "__name__", type(self).__name__)
    # Move every attribute that supports `.to()` (tensors, nn.Modules) to
    # the target device; everything else is skipped.
    for attr, value in vars(self).items():
        try:
            value = value.to(device)
        except AttributeError:
            pass
        else:
            print(f"Moving {name}.{attr} to {device}")
            setattr(self, attr, value)
    torch.cuda.empty_cache()


@make_mem_efficient
class Predictor(BasePredictor):
    cached_models = inference

    def setup(self):
        # Load the Wav2Lip GAN checkpoint once, at container start.
        inference.do_load("checkpoints/wav2lip_gan.pth")

    def predict(
        self,
        face: Path = Input(description="video/image that contains faces to use"),
        audio: Path = Input(description="video/audio file to use as raw audio source"),
        pads: str = Input(
            description="Padding for the detected face bounding box.\n"
            "Please adjust to include at least the chin.\n"
            'Format: "top bottom left right"',
            default="0 10 0 0",
        ),
        smooth: bool = Input(
            description="Smooth face detections over a short temporal window",
            default=True,
        ),
        fps: float = Input(
            description="Can be specified only if the input is a static image",
            default=25.0,
        ),
        out_height: int = Input(
            description="Output video height. Best results are obtained at 480 or 720",
            default=480,
        ),
    ) -> Path:
        # Remove any output left over from a previous run.
        try:
            os.remove("results/result_voice.mp4")
        except FileNotFoundError:
            pass

        # Validate input formats up front (case-insensitively) so we fail
        # fast instead of deep inside the pipeline.
        face_ext = os.path.splitext(face)[-1].lower()
        if face_ext not in [".mp4", ".mov", ".png", ".jpg", ".jpeg", ".gif", ".mkv", ".webp"]:
            raise ValueError(f"Unsupported face format {face_ext!r}")

        audio_ext = os.path.splitext(audio)[-1].lower()
        if audio_ext not in [".wav", ".mp3"]:
            raise ValueError(f"Unsupported audio format {audio_ext!r}")

        # Build the CLI-style argument list that inference.py's argparse
        # parser expects.
        args = [
            "--checkpoint_path", "checkpoints/wav2lip_gan.pth",
            "--face", str(face),
            "--audio", str(audio),
            "--pads", *pads.split(" "),
            "--fps", str(fps),
            "--out_height", str(out_height),
        ]
        if not smooth:
            args += ["--nosmooth"]

        print("-> run:", " ".join(args))
        inference.args = inference.parser.parse_args(args)

        s = time()

        try:
            inference.main()
        except ValueError as e:
            # Lip-sync failed (e.g. no face detected). Fall back to muxing
            # the original face video, looped to cover the audio, with the
            # new audio track, so the call still produces an output file.
            print("-> Encountered error, skipping lipsync:", e)
            args = [
                "ffmpeg", "-y",
                # "-vsync", "0", "-hwaccel", "cuda", "-hwaccel_output_format", "cuda",
                "-stream_loop", "-1",
                "-i", str(face),
                "-i", str(audio),
                "-shortest",
                "-fflags", "+shortest",
                "-max_interleave_delta", "100M",
                "-map", "0:v:0",
                "-map", "1:a:0",
                # "-c", "copy",
                # "-c:v", "h264_nvenc",
                "results/result_voice.mp4",
            ]
            print("-> run:", " ".join(args))
            print(subprocess.check_output(args, encoding="utf-8"))

        print(f"-> took {time() - s:.1f}s")

        return Path("results/result_voice.mp4")
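
# ---------------------------------------------------------------------------
# Minimal usage sketch (an assumption for illustration, not taken from this
# repo's docs): with the Cog CLI installed and this class registered as the
# predictor in cog.yaml, e.g.
#
#   predict: "predict.py:Predictor"
#
# a prediction can be run from the project root with something like:
#
#   cog predict -i face=@face.mp4 -i audio=@speech.wav \
#       -i pads="0 10 0 0" -i smooth=true -i out_height=480
#
# Cog builds the container, calls setup() once, then predict() per request;
# the finished video is returned from results/result_voice.mp4.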