# Provenance (captured from the hosting page's file viewer, not part of the code):
# author zxl, commit 07c6a04 "first commit", 4.68 kB.
import os
from functools import partial
from typing import Any, Optional
import imageio
import torch
import videosys
from .mp_utils import ProcessWorkerWrapper, ResultHandler, WorkerMonitor, get_distributed_init_method, get_open_port
class VideoSysEngine:
    """
    Drive a video-generation pipeline across one or more ranks.

    This is partly inspired by vLLM's multiprocessing executor: the driver
    (rank 0) pipeline is built in the current process, while ranks
    ``1..world_size-1`` are wrapped in ``ProcessWorkerWrapper`` objects and
    supervised by a ``WorkerMonitor``. Method calls are dispatched to every
    rank and the driver's result is what callers get back.
    """

    def __init__(self, config):
        # ``config`` must expose ``world_size`` and ``pipeline_cls``; it is also
        # handed verbatim to every pipeline constructor.
        self.config = config
        # Futures consumed by stop_remote_worker_execution_loop(). Nothing in
        # this class currently assigns a non-None value.
        self.parallel_worker_tasks = None
        self._init_worker(config.pipeline_cls)

    def _init_worker(self, pipeline_cls):
        """Start the worker ranks (1..world_size-1) and build the driver (rank 0).

        Args:
            pipeline_cls: Callable taking ``self.config`` and returning a
                pipeline object; instantiated once per rank.

        Raises:
            AssertionError: if ``world_size`` exceeds the visible CUDA devices.
        """
        world_size = self.config.world_size

        # Expose the first ``world_size`` GPUs unless the user pinned devices.
        if "CUDA_VISIBLE_DEVICES" not in os.environ:
            os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(str(i) for i in range(world_size))

        # Disable torch async compiling which won't work with daemonic processes
        os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1"

        # Set OMP_NUM_THREADS to 1 if it is not set explicitly, avoids CPU
        # contention amongst the shards
        if "OMP_NUM_THREADS" not in os.environ:
            os.environ["OMP_NUM_THREADS"] = "1"

        # NOTE: The two following lines need adaption for multi-node
        assert world_size <= torch.cuda.device_count(), (
            f"world_size={world_size} exceeds the {torch.cuda.device_count()} visible CUDA device(s)"
        )

        # change addr for multi-node
        distributed_init_method = get_distributed_init_method("127.0.0.1", get_open_port())

        if world_size == 1:
            # Single rank: the driver does everything, nothing to monitor.
            self.workers = []
            self.worker_monitor = None
        else:
            result_handler = ResultHandler()
            self.workers = [
                ProcessWorkerWrapper(
                    result_handler,
                    partial(
                        self._create_pipeline,
                        pipeline_cls=pipeline_cls,
                        rank=rank,
                        local_rank=rank,
                        distributed_init_method=distributed_init_method,
                    ),
                )
                for rank in range(1, world_size)
            ]

            self.worker_monitor = WorkerMonitor(self.workers, result_handler)
            result_handler.start()
            self.worker_monitor.start()

        # The driver is rank 0 (the _create_pipeline defaults) in this process.
        self.driver_worker = self._create_pipeline(
            pipeline_cls=pipeline_cls, distributed_init_method=distributed_init_method
        )

    # TODO: add more options here for pipeline, or wrap all options into config
    def _create_pipeline(self, pipeline_cls, rank=0, local_rank=0, distributed_init_method=None):
        """Initialize videosys for ``rank`` and construct one pipeline instance.

        Called directly for the driver and, through ``functools.partial``, by
        each ``ProcessWorkerWrapper`` (presumably inside its worker process;
        confirm in mp_utils). ``local_rank`` is accepted for vLLM-style parity
        but is not forwarded anywhere here. The seed is fixed at 42.
        """
        videosys.initialize(rank=rank, world_size=self.config.world_size, init_method=distributed_init_method, seed=42)
        pipeline = pipeline_cls(self.config)
        return pipeline

    def _run_workers(
        self,
        method: str,
        *args,
        async_run_tensor_parallel_workers_only: bool = False,
        max_concurrent_workers: Optional[int] = None,
        **kwargs,
    ) -> Any:
        """Run ``method`` on all ranks and collect the results.

        Args:
            method: Name of the pipeline method to invoke on every rank.
            *args, **kwargs: Forwarded unchanged to ``method``.
            async_run_tensor_parallel_workers_only: If True, only dispatch to
                the worker ranks and return their futures; the driver does
                not run.
            max_concurrent_workers: Accepted for vLLM API compatibility;
                currently ignored.

        Returns:
            ``[driver_output, *worker_outputs]`` in rank order, or the list of
            worker futures when ``async_run_tensor_parallel_workers_only``.
        """
        # Dispatch to the workers first so they can run while the driver works.
        worker_outputs = [worker.execute_method(method, *args, **kwargs) for worker in self.workers]

        if async_run_tensor_parallel_workers_only:
            # Just return futures
            return worker_outputs

        driver_worker_method = getattr(self.driver_worker, method)
        driver_worker_output = driver_worker_method(*args, **kwargs)

        # Get the results of the workers (each .get() blocks until ready).
        return [driver_worker_output] + [output.get() for output in worker_outputs]

    def _driver_execute_model(self, *args, **kwargs):
        """Call ``generate`` on the driver pipeline only (no worker dispatch)."""
        return self.driver_worker.generate(*args, **kwargs)

    def generate(self, *args, **kwargs):
        """Run ``generate`` on every rank and return the driver's (rank 0) output."""
        return self._run_workers("generate", *args, **kwargs)[0]

    def stop_remote_worker_execution_loop(self) -> None:
        """Wait for any pending background worker tasks, then forget them."""
        if self.parallel_worker_tasks is None:
            return

        parallel_worker_tasks = self.parallel_worker_tasks
        self.parallel_worker_tasks = None
        # Ensure that workers exit model loop cleanly
        # (this will raise otherwise)
        self._wait_for_tasks_completion(parallel_worker_tasks)

    def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None:
        """Wait for futures returned from _run_workers() with
        async_run_tensor_parallel_workers_only=True to complete."""
        for result in parallel_worker_tasks:
            result.get()

    def save_video(self, video, output_path, fps: int = 24):
        """Write ``video`` to ``output_path`` with ``imageio.mimwrite``.

        Args:
            video: Frame sequence accepted by ``imageio.mimwrite``.
            output_path: Destination file; missing parent directories are
                created.
            fps: Frames per second of the output (default 24, the value that
                was previously hard-coded).
        """
        parent_dir = os.path.dirname(output_path)
        # dirname("out.mp4") == "" and os.makedirs("") raises
        # FileNotFoundError, so only create a directory when there is one.
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        imageio.mimwrite(output_path, video, fps=fps)

    def shutdown(self):
        """Stop the worker monitor and tear down the default process group.

        Safe to call repeatedly (``__del__`` calls it again after an explicit
        shutdown) and on a partially constructed engine.
        """
        if (worker_monitor := getattr(self, "worker_monitor", None)) is not None:
            # Drop the reference first so a repeat call does not close it twice.
            self.worker_monitor = None
            worker_monitor.close()
        # destroy_process_group() fails when no default group is initialized,
        # e.g. on a repeat call or if __init__ failed before initialization.
        if torch.distributed.is_available() and torch.distributed.is_initialized():
            torch.distributed.destroy_process_group()

    def __del__(self):
        # Best-effort cleanup when the engine is garbage-collected.
        self.shutdown()