# Dependencies
"""pip install torch pillow requests diffusers imageio gradio==3.4 httpx==0.23.2 transformers accelerate"""
import functools

import gradio as gr
import imageio
import torch
import numpy as np
from PIL import Image
from diffusers import StableDiffusionInpaintPipeline

# Files shared between Mask() (which writes them) and perform_inpainting()
# (which reads them). Keeping the names in one place stops the two from drifting.
ORIGINAL_IMAGE_PATH = "Original Image.png"
MASK_IMAGE_PATH = "Mask Image.png"
OUTPUT_IMAGE_PATH = "Inpainted_img.png"

# NOTE(review): this looks like the base text-to-image checkpoint rather than a
# dedicated inpainting checkpoint -- confirm that this is the intended model.
MODEL_NAME = "runwayml/stable-diffusion-v1-5"

# Both the source image and the mask are resized to this size before inference.
IMAGE_SIZE = (512, 512)


def create_inpaint_pipeline(model_name, torch_dtype=torch.float16):
    """Load a Stable Diffusion inpainting pipeline.

    Args:
        model_name (str): Model identifier or local path passed to
            ``StableDiffusionInpaintPipeline.from_pretrained``.
        torch_dtype (torch.dtype): Weight precision to load. Defaults to
            ``torch.float16``, which is what this function always used before
            the parameter existed.

    Returns:
        StableDiffusionInpaintPipeline: The loaded pipeline. It is not moved to
        any device here; the caller does that.
    """
    return StableDiffusionInpaintPipeline.from_pretrained(
        model_name,
        torch_dtype=torch_dtype,
    )


@functools.lru_cache(maxsize=1)
def _get_pipeline(model_name, device):
    """Return a pipeline for *model_name* on *device*, loading it only once.

    Loading the weights is by far the slowest step, so the result is cached and
    reused across button clicks. A failed load raises and is not cached, so the
    next call retries.
    """
    # Half precision is only used on the GPU; full precision is the safe
    # choice when falling back to the CPU.
    torch_dtype = torch.float16 if device == "cuda" else torch.float32
    return create_inpaint_pipeline(model_name, torch_dtype).to(device)


def _load_rgb_image(path):
    """Open *path*, convert it to RGB and resize it to ``IMAGE_SIZE``."""
    # The context manager closes the file handle promptly; Mask() overwrites
    # these same files on the next click.
    with Image.open(path) as source:
        return source.convert("RGB").resize(IMAGE_SIZE)


def perform_inpainting(prompt):
    """Inpaint the previously saved image using the saved mask and a prompt.

    Reads ``ORIGINAL_IMAGE_PATH`` and ``MASK_IMAGE_PATH`` (written by
    ``Mask``), runs the inpainting pipeline and saves the result to
    ``OUTPUT_IMAGE_PATH``.

    Args:
        prompt (str): Text prompt passed to the pipeline.

    Returns:
        PIL.Image.Image | None: The generated image, or ``None`` when the
        input files are missing or inference fails (the reason is printed).
    """
    # Check the cheap precondition first so a missing file is reported
    # without paying for a model load.
    try:
        init_image = _load_rgb_image(ORIGINAL_IMAGE_PATH)
        mask_image = _load_rgb_image(MASK_IMAGE_PATH)
    except FileNotFoundError:
        print(f"Error: Image files '{ORIGINAL_IMAGE_PATH}' or '{MASK_IMAGE_PATH}' not found.")
        return None

    print("Processing the image...")

    # Use the GPU when one is available instead of assuming CUDA exists.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Loading the model is inside the try block so that download or device
    # errors are reported the same way as inference errors.
    try:
        pipeline = _get_pipeline(MODEL_NAME, device)
        image = pipeline(prompt=prompt, image=init_image, mask_image=mask_image).images[0]
        image.save(OUTPUT_IMAGE_PATH)
    except Exception as e:
        print(f"Error during inpainting: {e}")
        return None
    return image


def Mask(img):
    """Save the sketched image and its mask to disk and echo them back.

    Args:
        img (dict): Dictionary with an ``"image"`` entry (the base image) and
            a ``"mask"`` entry (the painted mask), as supplied by the sketch
            image input in ``main``. May be ``None`` when nothing was provided.

    Returns:
        tuple: ``(base_image, mask_image)`` on success. On any failure the
        reason is printed and ``(None, None)`` is returned; both outputs are
        image components, so an error string is not a displayable value.
    """
    if img is None:
        print("An error occurred: no image was provided")
        return None, None

    try:
        base_image = img["image"]
        mask_image = img["mask"]
    except KeyError as e:
        # The expected keys are not in the input dictionary.
        print(f"Key error: {e}")
        return None, None

    try:
        imageio.imwrite(ORIGINAL_IMAGE_PATH, base_image)
        imageio.imwrite(MASK_IMAGE_PATH, mask_image)
    except Exception as e:
        # Best effort: report the failure without breaking the UI callback.
        print(f"An error occurred: {e}")
        return None, None

    return base_image, mask_image


def main():
    """Build the Gradio interface and launch it."""
    with gr.Blocks() as demo:
        # Step 1: paint a mask over an image and save both to disk.
        with gr.Row():
            img = gr.Image(tool="sketch", label="Paint Image", show_label=True)
            img1 = gr.Image(label="Original Image")
            img2 = gr.Image(label="Mask Image", show_label=True)
        btn = gr.Button()
        btn.click(Mask, inputs=img, outputs=[img1, img2])

        # Step 2: describe what should fill the masked area and generate.
        with gr.Row():
            prompt = gr.Textbox(label="Enter the prompt")
            button = gr.Button("Click")
            output_image = gr.Image(label="Generated Image")
        button.click(perform_inpainting, inputs=prompt, outputs=output_image)

    demo.launch()


if __name__ == '__main__':
    main()