"""Gradio app: generate a clothing image and composite a generated design onto it."""

import gradio as gr
import torch
from PIL import Image
from torchvision import models, transforms
from transformers import pipeline

# Load the models.
# NOTE(review): "text-to-image-generation" does not appear to be a registered
# `transformers.pipeline` task, and this checkpoint is published for the
# `diffusers` library. Loading it probably requires
# `diffusers.StableDiffusionPipeline` -- confirm before relying on this line.
text_to_image_pipeline = pipeline("text-to-image-generation", model="CompVis/stable-diffusion-v1-4")
segmentation_model = models.segmentation.deeplabv3_resnet101(pretrained=True)
segmentation_model.eval()

# Transformation for the segmentation model. No resize/crop on purpose: the
# predicted mask must line up pixel-for-pixel with the image it is later
# composited onto, and a center crop would both shrink and shift it.
preprocess = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


def segment_clothing(image):
    """Return a per-pixel class-index map for *image*.

    Args:
        image: PIL image to segment. It is converted to RGB because the
            normalization above is defined for three channels.

    Returns:
        A 2-D uint8 NumPy array of argmax class indices, one per pixel of the
        tensor fed to the model (i.e. the input image's resolution).
    """
    input_batch = preprocess(image.convert("RGB")).unsqueeze(0)
    with torch.no_grad():
        output = segmentation_model(input_batch)["out"][0]
    # Collapse the class dimension to the most likely class per pixel.
    output_predictions = output.argmax(0)
    return output_predictions.byte().cpu().numpy()


def _generate_image(prompt):
    """Run the text-to-image pipeline on *prompt* and wrap the result as a PIL image."""
    return Image.fromarray(text_to_image_pipeline(prompt)[0])


def generate_base_image(base_prompt_part1, base_prompt_color, base_prompt_clothing):
    """Generate the base clothing image from the three prompt fragments.

    The fragments are joined with single spaces to form the full prompt.
    """
    base_prompt = f"{base_prompt_part1} {base_prompt_color} {base_prompt_clothing}"
    return _generate_image(base_prompt)


def generate_and_paste_design(base_image, design_prompt):
    """Generate a design from *design_prompt* and composite it onto *base_image*.

    Args:
        base_image: PIL image of the clothing item.
        design_prompt: Text prompt describing the design to generate.

    Returns:
        A PIL image showing the design where the segmentation mask is set and
        the base image everywhere else.
    """
    # The design must cover the same canvas as the base image for compositing.
    generated_design = _generate_image(design_prompt).resize(base_image.size)

    clothing_mask = segment_clothing(base_image)
    # Binarise before scaling: multiplying raw class indices by 255 in uint8
    # wraps around and yields class-dependent alpha values.
    # NOTE(review): this treats every non-zero class as "clothing"; the
    # pretrained label set does not seem to include a dedicated clothing
    # class -- confirm that this is the intended region.
    mask_image = Image.fromarray((clothing_mask > 0).astype("uint8") * 255)
    # Image.composite requires the mask to match the images' size exactly.
    if mask_image.size != base_image.size:
        mask_image = mask_image.resize(base_image.size, Image.NEAREST)

    return Image.composite(generated_design, base_image, mask_image)


# Create the Gradio interface. Components come from the top-level namespace;
# the `gr.inputs` / `gr.outputs` namespaces were removed in newer Gradio.
base_prompt_part1_input = gr.Textbox(lines=1, placeholder="Enter 'a single plain'")
base_prompt_color_input = gr.Textbox(lines=1, placeholder="Enter color type")
base_prompt_clothing_input = gr.Textbox(lines=1, placeholder="Enter clothing type")
design_prompt_input = gr.Textbox(lines=1, placeholder="Enter design prompt")
output_image = gr.Image(type="pil")


def full_process(base_prompt_part1, base_prompt_color, base_prompt_clothing, design_prompt):
    """Generate the base image, then generate the design and paste it on."""
    base_image = generate_base_image(base_prompt_part1, base_prompt_color, base_prompt_clothing)
    return generate_and_paste_design(base_image, design_prompt)


demo = gr.Interface(
    fn=full_process,
    inputs=[
        base_prompt_part1_input,
        base_prompt_color_input,
        base_prompt_clothing_input,
        design_prompt_input,
    ],
    outputs=output_image,
    title="Design and Paste on Clothing",
    description="Generate a base clothing image from the given prompts and paste the generated design onto it.",
)
demo.launch()