|
import gradio as gr |
|
from PIL import Image |
|
import torch |
|
from transformers import pipeline |
|
from torchvision import models, transforms |
|
|
|
|
|
# Text-to-image generator shared by every request.
# NOTE(review): confirm that the installed transformers release actually
# registers the "text-to-image-generation" task; Stable Diffusion checkpoints
# are normally loaded through the separate `diffusers` package.
text_to_image_pipeline = pipeline(
    "text-to-image-generation",
    model="CompVis/stable-diffusion-v1-4",
)

# Segmentation network, switched to inference mode straight away
# (``eval()`` returns the module itself, so the call can be chained).
segmentation_model = models.segmentation.deeplabv3_resnet101(pretrained=True).eval()
|
|
|
|
|
# Input transform for the segmentation model.
#
# The image is intentionally NOT resized or centre-cropped: a segmentation
# mask is only useful if it lines up pixel-for-pixel with the picture it was
# computed from. A classification-style Resize(256) + CenterCrop(224) would
# produce a 224x224 mask of the image centre, which can neither be used as an
# `Image.composite` mask for the full-size image nor be mapped back onto it
# without distortion.
#
# The mean/std values are the per-channel statistics that torchvision
# documents for its pretrained models.
preprocess = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
|
|
|
|
|
def segment_clothing(image):
    """Run the segmentation model on ``image`` and return a per-pixel label mask.

    Args:
        image: The picture to segment, normally a ``PIL.Image.Image``. PIL
            images in a mode other than RGB (for example RGBA or L) are
            converted to RGB first, because the module-level ``preprocess``
            transform normalises with three-channel statistics and would
            otherwise fail on them.

    Returns:
        A ``uint8`` NumPy array containing, for every pixel of the
        preprocessed image, the index of the highest-scoring class. Its
        spatial size is whatever ``preprocess`` produces.
    """
    if isinstance(image, Image.Image) and image.mode != "RGB":
        image = image.convert("RGB")

    # The network works on batches, so add a leading batch dimension.
    input_batch = preprocess(image).unsqueeze(0)

    # Inference only: skip gradient bookkeeping.
    with torch.no_grad():
        # 'out' holds the class scores; [0] selects the single batch item
        # (presumably shaped (classes, height, width) -- verify against the
        # torchvision documentation).
        output = segmentation_model(input_batch)["out"][0]

    # Collapse the class axis to the winning class index per pixel.
    return output.argmax(0).byte().cpu().numpy()
|
|
|
|
|
def generate_base_image(base_prompt_part1, base_prompt_color, base_prompt_clothing):
    """Generate the base garment picture from the three prompt fragments.

    The fragments are joined with single spaces, in the order given, and the
    resulting prompt is handed to the module-level ``text_to_image_pipeline``.

    Args:
        base_prompt_part1: Leading part of the prompt.
        base_prompt_color: Colour fragment of the prompt.
        base_prompt_clothing: Garment fragment of the prompt.

    Returns:
        A ``PIL.Image.Image`` built from the first pipeline result.
    """
    prompt = "{} {} {}".format(
        base_prompt_part1, base_prompt_color, base_prompt_clothing
    )

    # assumes the first pipeline result is array-like data accepted by
    # Image.fromarray -- TODO confirm against the pipeline's output format
    first_result = text_to_image_pipeline(prompt)[0]
    return Image.fromarray(first_result)
|
|
|
|
|
def generate_and_paste_design(base_image, design_prompt):
    """Generate a design from ``design_prompt`` and paste it onto ``base_image``.

    The design is stretched to the size of ``base_image`` and blended in only
    where ``segment_clothing`` reports a non-zero label.

    Args:
        base_image: ``PIL.Image.Image`` that receives the design.
        design_prompt: Text prompt describing the design to generate.

    Returns:
        A new ``PIL.Image.Image`` with the design composited onto the masked
        region of ``base_image``.
    """
    # assumes the first pipeline result is array-like data accepted by
    # Image.fromarray -- TODO confirm against the pipeline's output format
    generated_design = Image.fromarray(text_to_image_pipeline(design_prompt)[0])

    clothing_mask = segment_clothing(base_image)

    # Image.composite requires both pictures to share size and mode.
    generated_design = generated_design.resize(base_image.size)
    if generated_design.mode != base_image.mode:
        generated_design = generated_design.convert(base_image.mode)

    # The mask holds class indices, not 0/1. Multiplying those by 255 in uint8
    # wraps around (class k becomes 256 - k), so binarise explicitly instead.
    # NOTE(review): label 0 is presumably "background"; every other label is
    # treated as part of the garment -- confirm this is the intended region.
    binary_mask = (clothing_mask > 0).astype("uint8") * 255

    # The mask may come back at a different resolution than the base image;
    # composite() rejects mismatched sizes, so bring it to the same size.
    # NEAREST keeps the mask strictly 0/255.
    mask_image = Image.fromarray(binary_mask).resize(base_image.size, Image.NEAREST)

    return Image.composite(generated_design, base_image, mask_image)
|
|
|
|
|
# UI components: four single-line prompt boxes (created in the order the
# placeholders are listed) and one image output delivered as a PIL image.
# NOTE(review): the `gr.inputs` / `gr.outputs` namespaces belong to older
# Gradio releases -- confirm the pinned Gradio version still provides them.
(
    base_prompt_part1_input,
    base_prompt_color_input,
    base_prompt_clothing_input,
    design_prompt_input,
) = (
    gr.inputs.Textbox(lines=1, placeholder=hint)
    for hint in (
        "Enter 'a single plain'",
        "Enter color type",
        "Enter clothing type",
        "Enter design prompt",
    )
)

output_image = gr.outputs.Image(type="pil")
|
|
|
def full_process(base_prompt_part1, base_prompt_color, base_prompt_clothing, design_prompt):
    """Build the base garment image, then paste a generated design onto it.

    Args:
        base_prompt_part1: Leading part of the base-image prompt.
        base_prompt_color: Colour fragment of the base-image prompt.
        base_prompt_clothing: Garment fragment of the base-image prompt.
        design_prompt: Prompt for the design that is pasted onto the garment.

    Returns:
        Whatever ``generate_and_paste_design`` returns for the generated base
        image and ``design_prompt``.
    """
    return generate_and_paste_design(
        generate_base_image(base_prompt_part1, base_prompt_color, base_prompt_clothing),
        design_prompt,
    )
|
|
|
# Wire the four prompt boxes and the image output to `full_process`, then
# start the Gradio app as soon as the module is executed.
demo = gr.Interface(
    fn=full_process,
    inputs=[
        base_prompt_part1_input,
        base_prompt_color_input,
        base_prompt_clothing_input,
        design_prompt_input,
    ],
    outputs=output_image,
    title="Design and Paste on Clothing",
    description="Generate a base clothing image from the given prompts and paste the generated design onto it.",
)
demo.launch()
|
|