3DGen-Arena / model /model_worker.py
ZhangYuhan's picture
model_worker
b7f6a68
raw
history blame
No virus
5.09 kB
import json
import os
import subprocess
import sys
import tempfile
import time
from typing import List

import kiui
import replicate

from constants import OFFLINE_GIF_DIR
# Set REPLICATE_API_TOKEN in the process environment before running
# (e.g. os.environ["REPLICATE_API_TOKEN"] = "<your token>"); never commit a real token to this file.
class BaseModelWorker:
    """Common state and interface shared by the model workers in this file.

    Stores the model's identifying settings and, when a ``<model_name>.json``
    file exists under ``OFFLINE_GIF_DIR``, loads it as a table of offline
    results that ``load_offline`` can serve.
    """

    def __init__(self,
                 model_name: str,
                 i2s_model: bool,
                 online_model: bool,
                 model_api: str = None
                 ):
        """Record the worker settings and load the offline table if present.

        Args:
            model_name: Name of the model; also the stem of the offline JSON file.
            i2s_model: Stored as-is on the instance.
            online_model: Stored as-is; consulted by ``check_online``.
            model_api: Optional API identifier, stored as-is.
        """
        self.model_name = model_name
        self.i2s_model = i2s_model
        self.online_model = online_model
        self.model_api = model_api

        # Stays None when no offline file exists for this model.
        self.urls_json = None
        urls_json_path = os.path.join(OFFLINE_GIF_DIR, f"{model_name}.json")
        if os.path.exists(urls_json_path):
            with open(urls_json_path, 'r', encoding="utf-8") as f:
                self.urls_json = json.load(f)

    def check_online(self) -> bool:
        """Return True when the worker is flagged online and has no ``model`` set.

        ``model`` is never assigned in this class, so it is read with a
        default instead of raising ``AttributeError``.
        NOTE(review): "online and no model" returning True looks inverted or
        may have been meant to test ``model_api`` -- confirm with the callers.
        """
        return bool(self.online_model and not getattr(self, "model", None))

    def load_offline(self, offline: bool, offline_idx):
        """Look up a pre-generated result in the offline table.

        Args:
            offline: Whether offline lookup was requested.
            offline_idx: Index of the entry; converted with ``str()`` before lookup.

        Returns:
            The stored entry, or None when ``offline`` is false, no offline
            table was loaded, or the index is not in the table.
        """
        # The table is None whenever the JSON file was missing at construction.
        if not offline or not self.urls_json:
            return None
        return self.urls_json.get(str(offline_idx))

    def inference(self, prompt):
        """Hook for subclasses; the base implementation does nothing and returns None."""

    def render(self, shape, rgb_on=True, normal_on=True):
        """Hook for subclasses; the base implementation does nothing and returns None."""
class HuggingfaceApiWorker(BaseModelWorker):
    """Worker subclass that currently adds nothing beyond BaseModelWorker.

    The only difference visible here is that ``model_api`` has no default
    and must be supplied by the caller.
    """

    def __init__(self, model_name: str, i2s_model: bool, online_model: bool, model_api: str):
        """Forward every setting unchanged to the base-class constructor."""
        super().__init__(
            model_name=model_name,
            i2s_model=i2s_model,
            online_model=online_model,
            model_api=model_api,
        )
class PointE_Worker(BaseModelWorker):
    """Worker subclass with no behaviour of its own yet; ``model_api`` is required."""

    def __init__(
        self,
        model_name: str,
        i2s_model: bool,
        online_model: bool,
        model_api: str,
    ):
        """Pass all four settings straight through to BaseModelWorker."""
        super().__init__(
            model_name,
            i2s_model,
            online_model,
            model_api,
        )
class TriplaneGaussian(BaseModelWorker):
    """Worker subclass with no behaviour of its own yet; ``model_api`` is optional."""

    def __init__(
        self,
        model_name: str,
        i2s_model: bool,
        online_model: bool,
        model_api: str = None,
    ):
        """Delegate construction entirely to BaseModelWorker."""
        super().__init__(
            model_name=model_name,
            i2s_model=i2s_model,
            online_model=online_model,
            model_api=model_api,
        )
class LGM_Worker(BaseModelWorker):
    """Worker that runs a Replicate-hosted model and renders views of its output.

    ``inference`` sends an image to the Replicate model named by ``model_api``
    and returns the second item of the result; ``render`` converts that
    result to a mesh and renders normal and RGB views with ``kiui.render``.
    """

    # Module run via ``python -m``. The original relative name
    # ("..kiuikit.kiui.render") is rejected by ``-m``; the file imports
    # ``kiui``, so the installed package's renderer is used instead.
    # NOTE(review): confirm no local fork of kiuikit is expected here.
    RENDER_MODULE = "kiui.render"

    def __init__(self,
                 model_name: str,
                 i2s_model: bool,
                 online_model: bool,
                 model_api: str = "camenduru/lgm:d2870893aa115773465a823fe70fd446673604189843f39a99642dd9171e05e2",
                 ):
        """Build the base worker and a Replicate client.

        The API token is read from the ``REPLICATE_API_TOKEN`` environment
        variable rather than from an undefined module-level name. A missing
        token is passed on as None and not rejected here.
        """
        super().__init__(model_name, i2s_model, online_model, model_api)
        self.model_client = replicate.Client(
            api_token=os.environ.get("REPLICATE_API_TOKEN")
        )

    def inference(self, image):
        """Run the Replicate model on ``image`` and return the second output item.

        Args:
            image: Value sent as the model's ``input_image`` input.

        Returns:
            ``output[1]`` of the Replicate result.
        """
        output = self.model_client.run(
            self.model_api,
            input={"input_image": image}
        )
        # => .mp4 .ply  (index 1 is the .ply entry)
        return output[1]

    def render(self, shape, rgb_on=True, normal_on=True):
        """Convert ``shape`` to a mesh and render the requested view sets.

        The signature matches ``BaseModelWorker.render``; with the default
        flags both view sets are rendered, normal first and then RGB.

        Args:
            shape: Value handed to the mesh converter.
            rgb_on: Render the RGB views when true.
            normal_on: Render the normal views when true.

        Returns:
            ``(path_normal, path_rgb)``: the directory passed to ``--save``
            for each mode, or None for a mode that was switched off.

        Raises:
            subprocess.CalledProcessError: If a render command exits non-zero.
        """
        # NOTE(review): ``Gau2Mesh_client`` is not defined in the visible file;
        # it must be provided at module level before this method can run.
        mesh = Gau2Mesh_client.run(shape)
        path_normal = self._render_views(mesh, "normal") if normal_on else None
        path_rgb = self._render_views(mesh, "rgb") if rgb_on else None
        return path_normal, path_rgb

    def _render_views(self, mesh, mode):
        """Run the renderer on ``mesh`` in ``mode`` and return the output directory."""
        # The original passed an empty string to --save; use a fresh directory.
        out_dir = tempfile.mkdtemp(prefix=f"lgm_{mode}_")
        # Argument list with no shell, so the mesh path is never shell-parsed.
        cmd = [
            sys.executable, "-m", self.RENDER_MODULE, str(mesh),
            "--save", out_dir,
            "--wogui",
            "--H", "512",
            "--W", "512",
            "--radius", "3",
            "--elevation", "0",
            "--num_azimuth", "40",
            "--front_dir=+z",
            "--mode", mode,
        ]
        subprocess.run(cmd, check=True)
        return out_dir
class V3D_Worker(BaseModelWorker):
    """Worker subclass with no behaviour of its own yet; ``model_api`` is optional."""

    def __init__(
        self,
        model_name: str,
        i2s_model: bool,
        online_model: bool,
        model_api: str = None,
    ):
        """Hand all settings to BaseModelWorker without modification."""
        super().__init__(
            model_name,
            i2s_model,
            online_model,
            model_api,
        )
# model = 'LGM'
# # model = 'TriplaneGaussian'
# folder = 'glbs_full'
# form = 'glb'
# pose = '+z'
# pair = ('OpenLRM', 'meshes', 'obj', '-y')
# pair = ('TriplaneGaussian', 'glbs_full', 'glb', '-y')
# pair = ('LGM', 'glbs_full', 'glb', '+z')
if __name__ == "__main__":
    # Manual check of the splat-to-mesh step, using hard-coded result URLs
    # instead of a live Replicate call. The disabled call looked like this:
    #
    #   model_client = replicate.Client(api_token=REPLICATE_API_TOKEN)
    #   output = model_client.run(
    #       "camenduru/lgm:d2870893aa115773465a823fe70fd446673604189843f39a99642dd9171e05e2",
    #       input={"input_image": "https://replicate.delivery/pbxt/KN0hQI9pYB3NOpHLqktkkQIblwpXt0IG7qI90n5hEnmV9kvo/bird_rgba.png"},
    #   )
    #   #=> ['https://replicate.delivery/pbxt/toffawxRE3h6AUofI9sPtiAsoYI0v73zuGDZjZWBWAPzHKSlA/gradio_output.mp4',
    #   #    'https://replicate.delivery/pbxt/oSn1XPfoJuw2UKOUIAue2iXeT7aXncVjC4QwHKU5W5x0HKSlA/gradio_output.ply']
    #
    # NOTE(review): `Client` is not imported in the visible file -- presumably
    # gradio_client.Client; confirm and add the import before running.
    recorded_video_url = 'https://replicate.delivery/pbxt/RPSTEes37lzAJav3jy1lPuzizm76WGU4IqDcFcAMxhQocjUJA/gradio_output.mp4'
    recorded_ply_url = 'https://replicate.delivery/pbxt/2Vy8yrPO3PYiI1YJBxPXAzryR0SC0oyqW3XKPnXiuWHUuRqE/gradio_output.ply'
    output = [recorded_video_url, recorded_ply_url]

    to_mesh_client = Client(
        "https://dylanebert-splat-to-mesh.hf.space/",
        upload_files=True,
        download_files=True,
    )
    # Only the .ply entry (index 1) is sent to the "/run" endpoint.
    mesh = to_mesh_client.predict(output[1], api_name="/run")
    print(mesh)