# carvekit / app.py (author: leonelhs)
# Commit 699e9ad: "add net choices" (3.44 kB)
import gradio as gr
import torch
from carvekit.api.interface import Interface
from carvekit.ml.wrap.basnet import BASNET
from carvekit.ml.wrap.deeplab_v3 import DeepLabV3
from carvekit.ml.wrap.fba_matting import FBAMatting
from carvekit.ml.wrap.tracer_b7 import TracerUniversalB7
from carvekit.ml.wrap.u2net import U2NET
from carvekit.pipelines.postprocessing import MattingMethod
from carvekit.pipelines.preprocessing import PreprocessingStub
from carvekit.trimap.generator import TrimapGenerator
# Use the GPU when torch can see one; otherwise run everything on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Segmentation backends selectable in the UI, keyed by the label shown in the
# "Segmentor model" dropdowns. Each one is instantiated eagerly, at import
# time, in the order listed, with a batch size of one image.
segment_net = {
    label: net_class(device=device, batch_size=1)
    for label, net_class in (
        ("U2NET", U2NET),
        ("BASNET", BASNET),
        ("DeepLabV3", DeepLabV3),
        ("TracerUniversalB7", TracerUniversalB7),
    )
}
# Matting model shared by the post-processing pipeline; it runs on the same
# device as the segmentors.
fba = FBAMatting(
    device=device,
    input_tensor_size=2048,
    batch_size=1,
)
# Trimap generator: used both by MattingMethod below and directly by
# generate_trimap() for the "Trimap generator" tab.
trimap = TrimapGenerator()
# Pre/post pipelines handed to carvekit's Interface in predict().
preprocessing = PreprocessingStub()
postprocessing = MattingMethod(
    matting_module=fba,
    trimap_generator=trimap,
    device=device,
)
# Dropdown choices: the labels of every registered segmentation backend,
# in the same order as ``segment_net`` (iterating a dict yields its keys).
method_choices = list(segment_net)
def generate_trimap(method, original):
    """Build a trimap for ``original`` from the mask of the chosen segmentor.

    Args:
        method: Key into ``segment_net`` naming the segmentation backend
            (one of ``method_choices``, as selected in the dropdown).
        original: Input image from the Gradio ``Image`` component
            (``type="pil"``); ``None`` when the user has not uploaded one.

    Returns:
        The result of ``trimap(original_image=..., mask=...)``, shown in the
        "Trimap generator" tab's image output.

    Raises:
        gr.Error: If no input image was provided.
    """
    if original is None:
        # Without this guard the None would be handed to carvekit, which
        # presumably fails with a far less helpful message.
        raise gr.Error("Please upload an input image first.")
    # Segmentors take a batch (list) of images: send a one-image batch and
    # take the single resulting mask back out.
    masks = segment_net[method]([original])
    return trimap(original_image=original, mask=masks[0])
def predict(method, image):
    """Remove the background from ``image`` using the selected segmentor.

    The chosen segmentation network is wrapped in a carvekit ``Interface``
    together with the module-level ``preprocessing`` stub and the
    ``postprocessing`` matting pipeline.

    Args:
        method: Key into ``segment_net`` naming the segmentation backend
            (one of ``method_choices``, as selected in the dropdown).
        image: Input image from the Gradio ``Image`` component
            (``type="pil"``); ``None`` when the user has not uploaded one.

    Returns:
        The first (only) result of running the pipeline on a one-image batch.

    Raises:
        gr.Error: If no input image was provided.
    """
    if image is None:
        raise gr.Error("Please upload an input image first.")
    # Keep the label and the network under separate names instead of
    # rebinding ``method`` from a string to a model object.
    seg_net = segment_net[method]
    interface = Interface(
        pre_pipe=preprocessing,
        post_pipe=postprocessing,
        seg_pipe=seg_net,
    )
    return interface([image])[0]
# HTML footer rendered at the bottom of the page: the CarveKit logo and a
# link to the upstream CarveKit repository. ``<br>`` is the valid line-break
# element (``</br>`` is a markup error browsers only tolerate).
footer = """
<center>
<img src='https://raw.githubusercontent.com/leonelhs/image-background-remove-tool/master/docs/imgs/logo.png' alt='CarveKit' width="200" height="80">
<br>
<b>
Demo based on <a href='https://github.com/OPHoperHPO/image-background-remove-tool'>CarveKit</a>
</b>
</center>
"""
# Two-tab Gradio UI: full background removal via predict(), and a
# standalone trimap preview via generate_trimap(). Each tab owns its own
# segmentor dropdown so the selections are independent.
with gr.Blocks(title="CarveKit") as app:
    gr.Markdown("<center><h1><b>CarveKit</b></h1></center>")
    gr.HTML("<center><h3>High-quality image background removal</h3></center>")
    with gr.Tabs():
        with gr.TabItem("Remove background", id=0):
            with gr.Row(equal_height=False):
                with gr.Column():
                    input_img = gr.Image(type="pil", label="Input image")
                    drp_itf = gr.Dropdown(
                        value="TracerUniversalB7",
                        label="Segmentor model",
                        choices=method_choices)
                    run_btn = gr.Button(variant="primary")
                with gr.Column():
                    output_img = gr.Image(type="pil", label="result")
            run_btn.click(predict, [drp_itf, input_img], [output_img])
        with gr.TabItem("Trimap generator", id=1):
            with gr.Row(equal_height=False):
                with gr.Column():
                    trimap_input = gr.Image(type="pil", label="Input image")
                    # Distinct name: previously this rebound ``drp_itf``.
                    trimap_drp = gr.Dropdown(
                        value="TracerUniversalB7",
                        label="Segmentor model",
                        choices=method_choices)
                    trimap_btn = gr.Button(variant="primary")
                with gr.Column():
                    trimap_output = gr.Image(type="pil", label="result")
            trimap_btn.click(generate_trimap, [trimap_drp, trimap_input], [trimap_output])
    with gr.Row():
        gr.HTML(footer)
# Enable request queuing via Blocks.queue(); the old
# ``launch(enable_queue=True)`` keyword is deprecated in Gradio 3.x and no
# longer accepted by Gradio 4.
app.queue()
app.launch(share=False, debug=True, show_error=True)