# VFusion3D / app.py
# Header recovered from the Hugging Face file viewer:
# commit 68ae2ac ("Create app.py") by JunlinHan, 5.34 kB.
import os
import tempfile
from functools import lru_cache, partial

import gradio as gr
import kiui
import mcubes
import numpy as np
import torch
import trimesh
from kiui.op import recenter
from PIL import Image
from rembg import new_session, remove
from torchvision.utils import save_image
from transformers import AutoConfig, AutoModel
class LRMGeneratorWrapper:
    """Load the pre-trained generator from the Hugging Face Hub and run it.

    Attributes:
        config: model config fetched from the same Hub repository.
        model: the generator, moved to ``device`` and put in eval mode.
        device: CUDA when available, otherwise CPU.
    """

    def __init__(self):
        repo_id = "jadechoghari/custom-llrm"
        # The repository ships custom modelling code, hence trust_remote_code.
        self.config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
        self.model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)
        use_cuda = torch.cuda.is_available()
        self.device = torch.device('cuda' if use_cuda else 'cpu')
        self.model.to(self.device)
        self.model.eval()

    def forward(self, image, camera):
        """Call the wrapped model on ``image`` and ``camera`` and return its output as-is.

        The caller in this file treats the result as the triplane features
        (``planes``) consumed by the model's synthesizer.
        """
        return self.model(image, camera)
# Built once at import time (this downloads/loads the model); generate_mesh
# reuses this single instance for every request.
model_wrapper = LRMGeneratorWrapper()
@lru_cache(maxsize=1)
def _get_rembg_session():
    """Create the rembg "isnet-general-use" session on first use and reuse it afterwards."""
    return new_session("isnet-general-use")


def preprocess_image(image: "Image.Image", source_size: int) -> torch.Tensor:
    """Remove the background, recenter the object and resize it for the model.

    Args:
        image: input image (the Gradio handler passes a PIL image); it is
            converted with ``np.array`` before background removal.
        source_size: side length, in pixels, of the square output.

    Returns:
        Float tensor of shape (1, C, source_size, source_size) with values in
        [0, 1]. When the cutout has an alpha channel (4 channels), the
        foreground is composited onto white and C becomes 3.

    Raises:
        ValueError: if background removal leaves no foreground pixels.
    """
    # new_session builds the segmentation model, which is expensive; it used to
    # be rebuilt on every request, so it is now cached.
    session = _get_rembg_session()
    image = remove(np.array(image), session=session)
    # NOTE(review): the mask comes from a second rembg pass over the cutout,
    # as in the original code. The cutout's own alpha channel would avoid this
    # extra inference -- confirm it yields the same crop before switching.
    mask = remove(image, session=session, only_mask=True)
    if not np.any(mask):
        # Fail early with a clear message instead of inside recenter.
        raise ValueError("No foreground object was detected in the input image.")
    image = recenter(image, mask, border_ratio=0.20)
    image = torch.tensor(image).permute(2, 0, 1).unsqueeze(0) / 255.0
    if image.shape[1] == 4:
        # Alpha-composite onto a white background: rgb * a + 1 * (1 - a).
        rgb, alpha = image[:, :3, ...], image[:, 3:, ...]
        image = rgb * alpha + (1 - alpha)
    image = torch.nn.functional.interpolate(
        image, size=(source_size, source_size), mode='bicubic', align_corners=True
    )
    # Bicubic interpolation can overshoot slightly outside [0, 1].
    return torch.clamp(image, 0, 1)
def get_normalized_camera_intrinsics(intrinsics: torch.Tensor):
    """Normalize pixel-space intrinsics by the image size.

    Args:
        intrinsics: (N, 3, 2) tensor laid out as
            [[fx, fy], [cx, cy], [width, height]].

    Returns:
        Tuple (fx, fy, cx, cy) of (N,) tensors: fx and cx divided by width,
        fy and cy divided by height.
    """
    # One broadcast divides the [[fx, fy], [cx, cy]] rows by the
    # [width, height] row: column 0 is scaled by width, column 1 by height.
    normalized = intrinsics[:, :2, :2] / intrinsics[:, 2:3, :2]
    fx, fy = normalized[:, 0, 0], normalized[:, 0, 1]
    cx, cy = normalized[:, 1, 0], normalized[:, 1, 1]
    return fx, fy, cx, cy
def build_camera_principle(RT: torch.Tensor, intrinsics: torch.Tensor):
    """Pack extrinsics and normalized intrinsics into one flat camera vector per view.

    Args:
        RT: (N, 3, 4) extrinsic matrices.
        intrinsics: (N, 3, 2) tensor laid out as
            [[fx, fy], [cx, cy], [width, height]].

    Returns:
        (N, 16) tensor: the 12 flattened RT entries followed by the
        normalized fx, fy, cx, cy.
    """
    flat_extrinsics = RT.reshape(-1, 12)
    normalized_intrinsics = torch.stack(get_normalized_camera_intrinsics(intrinsics), dim=-1)
    return torch.cat((flat_extrinsics, normalized_intrinsics), dim=-1)
def _default_intrinsics():
fx = fy = 384
cx = cy = 256
w = h = 512
intrinsics = torch.tensor([
[fx, fy],
[cx, cy],
[w, h],
], dtype=torch.float32)
return intrinsics
def _default_source_camera(batch_size: int = 1):
    """Return the canonical source-view camera, repeated ``batch_size`` times.

    Args:
        batch_size: number of rows in the returned tensor.

    Returns:
        (batch_size, 16) float32 tensor in the layout of
        ``build_camera_principle``: the 12 flattened extrinsic entries followed
        by the normalized fx, fy, cx, cy from ``_default_intrinsics``.
    """
    # Fixed 3x4 extrinsic [R | t]. R is a permutation (hence orthonormal) and
    # t = (1, 0, 0), so the camera centre sits at distance 1 from the origin.
    # The distance is encoded directly in t; there is no separate parameter.
    # NOTE(review): this pose must match the one the checkpoint was trained
    # with, so change it only together with the model.
    canonical_camera_extrinsics = torch.tensor([[
        [0, 0, 1, 1],
        [1, 0, 0, 0],
        [0, 1, 0, 0],
    ]], dtype=torch.float32)
    canonical_camera_intrinsics = _default_intrinsics().unsqueeze(0)
    source_camera = build_camera_principle(canonical_camera_extrinsics, canonical_camera_intrinsics)
    return source_camera.repeat(batch_size, 1)
# Ref: https://github.com/jadechoghari/vfusion3d/blob/main/lrm/inferrer.py
def generate_mesh(image, source_size=512, render_size=384, mesh_size=512, export_mesh=True,
                  mesh_threshold=1.0):
    """Reconstruct a vertex-coloured mesh from a single image and export it as OBJ.

    Args:
        image: input image, passed to ``preprocess_image``.
        source_size: side length the image is resized to before inference.
        render_size: currently unused; reserved for video rendering (see TODO).
        mesh_size: resolution of the grid queried from the synthesizer and fed
            to marching cubes.
        export_mesh: when False, no mesh is extracted and ``None`` is returned.
        mesh_threshold: iso-level on the ``sigma`` grid used by marching cubes.

    Returns:
        Path to the exported ``.obj`` file, or ``None`` when ``export_mesh`` is
        False.
    """
    image = preprocess_image(image, source_size).to(model_wrapper.device)
    source_camera = _default_source_camera(batch_size=1).to(model_wrapper.device)
    # TODO: exporting a video would also need render cameras, e.g. via a
    # `_default_render_cameras(batch_size=1)` helper (not defined here).
    with torch.no_grad():
        planes = model_wrapper.forward(image, source_camera)
        if not export_mesh:
            return None
        synthesizer = model_wrapper.model.synthesizer
        grid_out = synthesizer.forward_grid(planes=planes, grid_size=mesh_size)
        sigma = grid_out['sigma'].float().squeeze(0).squeeze(-1).cpu().numpy()
        vtx, faces = mcubes.marching_cubes(sigma, mesh_threshold)
        # Map vertex coordinates from grid indices [0, mesh_size - 1] to [-1, 1].
        vtx = vtx / (mesh_size - 1) * 2 - 1
        vtx_tensor = torch.tensor(vtx, dtype=torch.float32, device=model_wrapper.device).unsqueeze(0)
        vtx_colors = synthesizer.forward_points(planes, vtx_tensor)['rgb'].float().squeeze(0).cpu().numpy()
    # Clip before casting: values outside [0, 1] are not safely representable
    # in uint8 after scaling.
    vtx_colors = (np.clip(vtx_colors, 0.0, 1.0) * 255).astype(np.uint8)
    mesh = trimesh.Trimesh(vertices=vtx, faces=faces, vertex_colors=vtx_colors)
    # Write every result into its own fresh directory so that concurrent
    # requests cannot overwrite each other's file; the file name itself stays
    # "awesome_mesh.obj". The files are left for the UI to serve and are not
    # cleaned up here.
    mesh_path = os.path.join(tempfile.mkdtemp(prefix="vfusion3d_"), "awesome_mesh.obj")
    mesh.export(mesh_path, 'obj')
    return mesh_path
def step_1_generate_obj(image):
    """Convert an input image into an OBJ mesh file and return its path.

    Thin wrapper around ``generate_mesh`` with its default settings.
    """
    return generate_mesh(image)
def step_2_display_3d_model(mesh_file):
    """Return ``mesh_file`` unchanged so it can be handed to the 3D viewer.

    No conversion happens here: ``gr.Model3D`` takes the mesh file path
    directly.
    """
    path_for_viewer = mesh_file
    return path_for_viewer
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            img_input = gr.Image(type="pil", label="Input Image")
            generate_button = gr.Button("Generate and Visualize 3D Model")
            obj_file_output = gr.File(label="Download .obj File")
        with gr.Column():
            model_output = gr.Model3D(clear_color=[0.0, 0.0, 0.0, 0.0], label="3D Model Visualization")

    def generate_and_visualize(image):
        """Generate a mesh from ``image`` and return its path for both outputs.

        Returns:
            (download path, viewer path) for the File and Model3D components.

        Raises:
            gr.Error: if the button is clicked before an image is uploaded.
        """
        if image is None:
            # Without this guard, None reaches background removal and fails
            # with an obscure error.
            raise gr.Error("Please upload an input image first.")
        mesh_path = step_1_generate_obj(image)
        return mesh_path, step_2_display_3d_model(mesh_path)

    generate_button.click(generate_and_visualize, inputs=img_input, outputs=[obj_file_output, model_output])

# Launch only when run as a script, so the module can be imported (e.g. by
# `gradio app.py` reload mode or tests) without starting a server.
if __name__ == "__main__":
    demo.launch()