|
import tempfile |
|
|
|
import imageio |
|
import numpy as np |
|
import PIL.Image |
|
import torch |
|
from shap_e.diffusion.gaussian_diffusion import diffusion_from_config |
|
from shap_e.diffusion.sample import sample_latents |
|
from shap_e.models.download import load_config, load_model |
|
from shap_e.models.nn.camera import (DifferentiableCameraBatch, |
|
DifferentiableProjectiveCamera) |
|
from shap_e.models.transmitter.base import Transmitter, VectorDecoder |
|
from shap_e.util.collections import AttrDict |
|
from shap_e.util.image_util import load_image |
|
|
|
|
|
|
|
def create_pan_cameras(size: int,
                       device: torch.device,
                       num_frames: int = 20) -> DifferentiableCameraBatch:
    """Build a ring of cameras orbiting the origin, for turntable renders.

    Every camera sits 4 units from the origin and looks straight at it. Its
    viewing direction goes around the ring and carries a fixed -0.5 z
    component (before normalisation), so the ring is raised on the +z side.
    Both field-of-view values are 0.7.

    Args:
        size: Width and height, in pixels, of every view.
        device: Device on which the camera tensors are allocated.
        num_frames: Number of cameras around the ring. The default, 20, is
            the value this function always used before the parameter existed.

    Returns:
        A ``DifferentiableCameraBatch`` of shape ``(1, num_frames)``.

    Raises:
        ValueError: if ``num_frames`` is smaller than 1.
    """
    if num_frames < 1:
        raise ValueError(f'num_frames must be at least 1, got {num_frames}')

    origins = []
    xs = []
    ys = []
    zs = []
    # NOTE: np.linspace includes both endpoints, so the theta=0 and theta=2*pi
    # cameras coincide (first and last frame are identical). This is kept so
    # existing renders stay unchanged.
    for theta in np.linspace(0, 2 * np.pi, num=num_frames):
        # Unit viewing direction: around the ring, tilted by -0.5 in z.
        z = np.array([np.sin(theta), np.cos(theta), -0.5])
        z /= np.sqrt(np.sum(z**2))
        # Step 4 units back along the view direction so the camera faces
        # the origin.
        origin = -z * 4
        # Horizontal image axis: tangent to the ring, orthogonal to z.
        x = np.array([np.cos(theta), -np.sin(theta), 0.0])
        # Vertical image axis completes the orthonormal frame.
        y = np.cross(z, x)
        origins.append(origin)
        xs.append(x)
        ys.append(y)
        zs.append(z)

    def _to_tensor(vectors):
        # Stack per-camera vectors into an (N, 3) float32 tensor on `device`.
        return torch.from_numpy(np.stack(vectors, axis=0)).float().to(device)

    return DifferentiableCameraBatch(
        shape=(1, len(xs)),
        flat_camera=DifferentiableProjectiveCamera(
            origin=_to_tensor(origins),
            x=_to_tensor(xs),
            y=_to_tensor(ys),
            z=_to_tensor(zs),
            width=size,
            height=size,
            x_fov=0.7,
            y_fov=0.7,
        ),
    )
|
|
|
|
|
|
|
@torch.no_grad()
def decode_latent_images(
    xm: Transmitter | VectorDecoder,
    latent: torch.Tensor,
    cameras: DifferentiableCameraBatch,
    rendering_mode: str = 'stf',
):
    """Render a single latent from every camera and return the views as images.

    Runs under ``torch.no_grad()``.

    Args:
        xm: Decoder whose ``renderer`` produces the views. For a
            ``Transmitter``, its ``encoder`` maps the latent to renderer
            parameters; otherwise ``xm`` does that mapping itself.
        latent: One latent without a batch axis; a leading axis of size 1 is
            added here.
        cameras: Cameras to render from.
        rendering_mode: Forwarded as ``rendering_mode`` in the renderer
            options.

    Returns:
        A list of ``PIL.Image.Image``, one per entry of the first (and only)
        batch element of the rendered channels.
    """
    # Pick whichever object owns bottleneck_to_params for this decoder type.
    param_owner = xm.encoder if isinstance(xm, Transmitter) else xm
    batched_latent = latent[None]
    render_options = AttrDict(rendering_mode=rendering_mode,
                              render_with_direction=False)

    decoded = xm.renderer.render_views(
        AttrDict(cameras=cameras),
        params=param_owner.bottleneck_to_params(batched_latent),
        options=render_options,
    )

    # Keep values within the uint8 range before the cast.
    # NOTE(review): presumably channels is (batch, view, H, W, C) — confirm.
    pixels = decoded.channels.clamp(0, 255).to(torch.uint8)
    views = pixels[0].cpu().numpy()
    return list(map(PIL.Image.fromarray, views))
|
|
|
|
|
class Model:
    """Generate Shap-E 3D assets from text or an image and render a pan video.

    The transmitter (``xm``) and the diffusion process are built once in
    ``__init__``. The conditional latent model is loaded lazily by
    ``load_model`` and kept until a different one is requested.
    """

    # Conditional latent models this wrapper can drive.
    _MODEL_NAMES = ('text300M', 'image300M')

    def __init__(self):
        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')
        # Decoder used to render latents in both run_text and run_image.
        self.xm = load_model('transmitter', device=self.device)
        self.diffusion = diffusion_from_config(load_config('diffusion'))
        # Name of the currently loaded conditional model ('' when none).
        self.model_name = ''
        self.model = None

    def load_model(self, model_name: str) -> None:
        """Make ``model_name`` the active conditional model, loading if needed.

        Args:
            model_name: Either 'text300M' or 'image300M'.

        Raises:
            ValueError: if ``model_name`` is not a supported model.
        """
        # An explicit check (not ``assert``) so it also runs under ``python -O``.
        if model_name not in self._MODEL_NAMES:
            raise ValueError(
                f'Unsupported model {model_name!r}; expected one of '
                f'{self._MODEL_NAMES}')
        if model_name == self.model_name:
            return
        # Release the previous model before loading the next one, so both are
        # never referenced at the same time. Clearing the name as well keeps
        # the cache consistent if the load below raises.
        self.model = None
        self.model_name = ''
        # This is the module-level shap_e loader, not this method.
        self.model = load_model(model_name, device=self.device)
        self.model_name = model_name

    @staticmethod
    def to_video(frames: list[PIL.Image.Image], fps: int = 5) -> str:
        """Encode ``frames`` as an MP4 in a new temporary file; return its path.

        The file is created with ``delete=False``, so it outlives this call.
        Removing it is the caller's responsibility.
        """
        # Only reserve a unique path here. Closing the handle immediately
        # avoids leaking a file descriptor per call; the writer opens the path
        # itself.
        with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as tmp:
            out_path = tmp.name
        # The context manager closes the writer even if a frame fails.
        with imageio.get_writer(out_path, format='FFMPEG', fps=fps) as writer:
            for frame in frames:
                writer.append_data(np.asarray(frame))
        return out_path

    def _sample_and_render(self, model_kwargs: dict, guidance_scale: float,
                           num_steps: int, output_image_size: int,
                           render_mode: str) -> str:
        """Sample one latent, render it from pan cameras, and save an MP4.

        This is the shared tail of run_text and run_image. The caller must
        already have loaded the model and seeded torch.
        """
        latents = sample_latents(
            batch_size=1,
            model=self.model,
            diffusion=self.diffusion,
            guidance_scale=guidance_scale,
            model_kwargs=model_kwargs,
            progress=True,
            clip_denoised=True,
            use_fp16=True,
            use_karras=True,
            karras_steps=num_steps,
            sigma_min=1e-3,
            sigma_max=160,
            s_churn=0,
        )
        cameras = create_pan_cameras(output_image_size, self.device)
        frames = decode_latent_images(self.xm,
                                      latents[0],
                                      cameras,
                                      rendering_mode=render_mode)
        return self.to_video(frames)

    def run_text(self,
                 prompt: str,
                 seed: int = 0,
                 guidance_scale: float = 15.0,
                 num_steps: int = 64,
                 output_image_size: int = 64,
                 render_mode: str = 'nerf') -> str:
        """Generate an asset from a text prompt and return its pan-video path.

        Args:
            prompt: Text conditioning for the 'text300M' model.
            seed: Passed to ``torch.manual_seed`` before sampling.
            guidance_scale: Forwarded to ``sample_latents``.
            num_steps: Forwarded as ``karras_steps``.
            output_image_size: Width and height of each frame, in pixels.
            render_mode: Rendering mode forwarded to the decoder.

        Returns:
            Path to a temporary MP4 file that is not deleted automatically.
        """
        self.load_model('text300M')
        # Seeding after load_model means sampling starts from the same RNG
        # state whether or not a model had to be loaded on this call.
        torch.manual_seed(seed)
        return self._sample_and_render(dict(texts=[prompt]), guidance_scale,
                                       num_steps, output_image_size,
                                       render_mode)

    def run_image(self,
                  image_path: str,
                  seed: int = 0,
                  guidance_scale: float = 3.0,
                  num_steps: int = 64,
                  output_image_size: int = 64,
                  render_mode: str = 'nerf') -> str:
        """Generate an asset from an image file and return its pan-video path.

        Args:
            image_path: Path of the conditioning image, read with
                ``load_image``.
            seed: Passed to ``torch.manual_seed`` before sampling.
            guidance_scale: Forwarded to ``sample_latents``.
            num_steps: Forwarded as ``karras_steps``.
            output_image_size: Width and height of each frame, in pixels.
            render_mode: Rendering mode forwarded to the decoder.

        Returns:
            Path to a temporary MP4 file that is not deleted automatically.
        """
        self.load_model('image300M')
        # Seeding after load_model means sampling starts from the same RNG
        # state whether or not a model had to be loaded on this call.
        torch.manual_seed(seed)
        image = load_image(image_path)
        return self._sample_and_render(dict(images=[image]), guidance_scale,
                                       num_steps, output_image_size,
                                       render_mode)
|
|