NewModel / app.py
gaur3009's picture
Create app.py
a3d4537 verified
raw
history blame
No virus
3.19 kB
import gradio as gr
from PIL import Image
import torch
from transformers import pipeline
from torchvision import models, transforms
# Load the models
# Both models are created once at import time and shared by every request.
# NOTE(review): "text-to-image-generation" does not appear among the task names
# documented for transformers.pipeline, and this Stable Diffusion checkpoint is
# normally loaded through the `diffusers` library — confirm this call succeeds.
text_to_image_pipeline = pipeline("text-to-image-generation", model="CompVis/stable-diffusion-v1-4")
# NOTE(review): newer torchvision releases replace `pretrained=True` with a
# `weights=` argument — confirm against the pinned torchvision version.
segmentation_model = models.segmentation.deeplabv3_resnet101(pretrained=True)
# Put the segmentation network in evaluation mode; it is only used for inference.
segmentation_model.eval()
# Define transformation for the segmentation model.
# Deliberately no Resize/CenterCrop: cropping would discard the image border and
# produce a mask whose size differs from the image it is later composited onto.
preprocess = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


def segment_clothing(image):
    """Return a per-pixel class-index mask for ``image``.

    Args:
        image: A PIL image. It is converted to RGB first because the
            normalisation above is defined for exactly three channels.

    Returns:
        A 2-D ``uint8`` NumPy array of predicted class indices with the same
        height and width as ``image`` (torchvision's segmentation models
        upsample their logits back to the input resolution).

    NOTE(review): the pretrained DeepLabV3 weights predict generic object
    categories, with index 0 presumably meaning background; there is no
    dedicated "clothing" class — confirm this model suits the use case.
    """
    input_tensor = preprocess(image.convert("RGB"))
    # The model expects a batch dimension in front of (C, H, W).
    input_batch = input_tensor.unsqueeze(0)
    with torch.no_grad():
        # 'out' holds the main head's logits; [0] drops the batch dimension.
        output = segmentation_model(input_batch)['out'][0]
    # Highest-scoring class per pixel, moved to host memory as uint8.
    output_predictions = output.argmax(0)
    return output_predictions.byte().cpu().numpy()
# Function to generate base image
def generate_base_image(base_prompt_part1, base_prompt_color, base_prompt_clothing):
    """Generate the plain garment image described by three prompt fragments.

    Args:
        base_prompt_part1: Leading phrase of the prompt (e.g. "a single plain").
        base_prompt_color: Colour of the garment.
        base_prompt_clothing: Type of garment.

    Returns:
        A PIL image built from the first element of the pipeline result.
    """
    fragments = (base_prompt_part1, base_prompt_color, base_prompt_clothing)
    # Drop missing/blank fragments so an empty field neither leaves stray
    # whitespace in the prompt nor injects the literal text "None".
    cleaned = (str(part).strip() for part in fragments if part is not None)
    base_prompt = " ".join(piece for piece in cleaned if piece)
    # Assumes the pipeline result is indexable and that its first element is an
    # array-like accepted by Image.fromarray — TODO confirm against the pipeline.
    base_image = text_to_image_pipeline(base_prompt)[0]
    return Image.fromarray(base_image)
# Define the function to generate design and paste it on the clothing
def generate_and_paste_design(base_image, design_prompt):
    """Generate a design image and composite it onto the segmented region.

    Args:
        base_image: PIL image of the garment to decorate.
        design_prompt: Text prompt describing the design to generate.

    Returns:
        A PIL image the size of ``base_image`` in which pixels selected by the
        segmentation mask come from the generated design and every other pixel
        comes from ``base_image``.
    """
    # Generate design.
    # Assumes the pipeline's first element is an array-like — TODO confirm.
    generated_image = text_to_image_pipeline(design_prompt)[0]
    generated_design = Image.fromarray(generated_image)

    # Segment the clothing area (2-D array of class indices).
    class_map = segment_clothing(base_image)

    # Binarise before scaling: multiplying raw uint8 class indices by 255 wraps
    # modulo 256 for every index >= 2, yielding arbitrary partial-alpha values.
    # Any non-zero class is treated as foreground here.
    binary_mask = (class_map != 0).astype("uint8") * 255
    mask_image = Image.fromarray(binary_mask)
    # Image.composite requires the mask to match the images' size exactly.
    if mask_image.size != base_image.size:
        mask_image = mask_image.resize(base_image.size, Image.NEAREST)

    # Ensure the generated design fits within the clothing area; composite also
    # requires both images to share the same mode.
    generated_design = generated_design.resize(base_image.size)
    if generated_design.mode != base_image.mode:
        generated_design = generated_design.convert(base_image.mode)

    # Paste the design onto the clothing area.
    return Image.composite(generated_design, base_image, mask_image)
# Create the Gradio interface components.
# The top-level component classes are used because the gr.inputs / gr.outputs
# namespaces were deprecated in Gradio 3 and removed in Gradio 4.
base_prompt_part1_input = gr.Textbox(lines=1, placeholder="Enter 'a single plain'")
base_prompt_color_input = gr.Textbox(lines=1, placeholder="Enter color type")
base_prompt_clothing_input = gr.Textbox(lines=1, placeholder="Enter clothing type")
design_prompt_input = gr.Textbox(lines=1, placeholder="Enter design prompt")
# type="pil" so the callback can return a PIL image directly.
output_image = gr.Image(type="pil")
def full_process(base_prompt_part1, base_prompt_color, base_prompt_clothing, design_prompt):
    """Run both stages: build the garment image, then overlay the design.

    The first three arguments are forwarded to ``generate_base_image``; its
    result and ``design_prompt`` are then handed to
    ``generate_and_paste_design``, whose return value is returned unchanged.
    """
    garment = generate_base_image(
        base_prompt_part1,
        base_prompt_color,
        base_prompt_clothing,
    )
    return generate_and_paste_design(garment, design_prompt)
# Assemble the app from the callback and components defined above, then start
# serving it.
demo = gr.Interface(
    fn=full_process,
    inputs=[
        base_prompt_part1_input,
        base_prompt_color_input,
        base_prompt_clothing_input,
        design_prompt_input,
    ],
    outputs=output_image,
    title="Design and Paste on Clothing",
    description="Generate a base clothing image from the given prompts and paste the generated design onto it.",
)
demo.launch()