# ai-filters / main.py
# tejavardhan's picture
# Upload main.py
# a6c15c9 verified
# raw
# history blame
# No virus
# 3.34 kB
# Dependencies
"""pip install torch pillow requests diffusers imageio gradio==3.4 httpx==0.23.2 transformers accelerate"""
import functools

import gradio as gr
import imageio
import numpy as np
import torch
from diffusers import StableDiffusionInpaintPipeline
from PIL import Image
@functools.lru_cache(maxsize=1)
def _get_pipeline(model_name, device):
    """Build the inpainting pipeline once, move it to *device*, and reuse it.

    Loading the Stable Diffusion weights is by far the slowest step, so the
    result is cached across calls (i.e. across Gradio button clicks) instead
    of being reloaded every time. A failed load raises and is not cached.
    """
    return create_inpaint_pipeline(model_name).to(device)


def perform_inpainting(prompt):
    """Inpaint the saved image/mask pair according to *prompt*.

    Reads "Original Image.png" and "Mask Image.png" from the working
    directory (these are written by ``Mask``), resizes both to 512x512, runs
    the Stable Diffusion inpainting pipeline on CUDA, saves the result to
    "Inpainted_img.png", and returns it.

    Args:
        prompt (str): Text describing what to paint into the masked region.

    Returns:
        PIL.Image.Image | None: The inpainted image, or ``None`` if either
        input file is missing or inference fails (the error is printed).
    """
    img_path = "Original Image.png"
    mask_path = "Mask Image.png"
    # Inference runs on CUDA; create_inpaint_pipeline loads float16 weights.
    device = "cuda"
    # NOTE(review): this is the base SD v1.5 checkpoint rather than a
    # dedicated inpainting checkpoint -- confirm this is intended.
    model_name = "runwayml/stable-diffusion-v1-5"

    # Load the inputs before the model so a missing file fails fast without
    # paying for a model load. `with` closes the underlying file handles.
    try:
        with Image.open(img_path) as src:
            init_image = src.convert("RGB").resize((512, 512))
        with Image.open(mask_path) as src:
            mask_image = src.convert("RGB").resize((512, 512))
    except FileNotFoundError:
        print(f"Error: Image files '{img_path}' or '{mask_path}' not found.")
        return None

    pipeline = _get_pipeline(model_name, device)

    print("Processing the image...")
    try:
        result = pipeline(
            prompt=prompt, image=init_image, mask_image=mask_image
        )
        image = result.images[0]
        image.save("Inpainted_img.png")
    except Exception as e:
        # UI boundary: report the failure and let Gradio show an empty output.
        print(f"Error during inpainting: {e}")
        return None
    return image
def create_inpaint_pipeline(model_name, torch_dtype=torch.float16):
    """Load a Stable Diffusion inpainting pipeline from a pretrained checkpoint.

    Args:
        model_name (str): Model id or local path passed to
            ``StableDiffusionInpaintPipeline.from_pretrained``.
        torch_dtype (torch.dtype): dtype for the loaded weights. Defaults to
            ``torch.float16`` (the previously hard-coded value); callers that
            need another precision, e.g. ``torch.float32`` when running
            without a GPU, can now pass it explicitly.

    Returns:
        StableDiffusionInpaintPipeline: The loaded pipeline. No ``.to()`` is
        called here; moving it to a device is the caller's job.
    """
    return StableDiffusionInpaintPipeline.from_pretrained(
        model_name,
        torch_dtype=torch_dtype,
    )
# if __name__ == "__main__":
# generated_image = perform_inpainting()
# if generated_image is not None:
# generated_image.show()
# # Optionally save the generated image
# generated_image.save("inpainted_image.png")
def Mask(img):
    """Save the sketch-tool image and its painted mask for later inpainting.

    Writes the base image to "Original Image.png" and the mask to
    "Mask Image.png" in the working directory; ``perform_inpainting`` reads
    those files back.

    Args:
        img (dict | None): Value from a ``gr.Image(tool="sketch")`` input,
            expected to contain ``"image"`` and ``"mask"`` entries. ``None``
            when nothing has been uploaded yet.

    Returns:
        tuple: ``(image, mask)`` on success. On failure the error is printed
        and ``(None, None)`` is returned so both output images are cleared.
        Returning the error text instead would make a Gradio Image output try
        to open that text as a file path.
    """
    if img is None:
        print("Error: no image provided.")
        return None, None
    try:
        base = img["image"]
        mask = img["mask"]
        imageio.imwrite("Original Image.png", base)
        imageio.imwrite("Mask Image.png", mask)
    except KeyError as e:
        # The input dictionary lacks one of the expected entries.
        print(f"Key error: {e}")
        return None, None
    except Exception as e:
        # UI boundary: report any other failure (e.g. while writing files).
        print(f"An error occurred: {e}")
        return None, None
    return base, mask
def main():
    """Build the two-step Gradio UI and launch it.

    Step 1 lets the user paint a mask and save the image/mask pair (``Mask``).
    Step 2 takes a text prompt and runs ``perform_inpainting`` on that pair.
    """
    with gr.Blocks() as demo:
        # Step 1: sketch input plus previews of the saved image and mask.
        with gr.Row():
            sketch_input = gr.Image(tool="sketch", label="Paint Image", show_label=True)
            original_view = gr.Image(label="Original Image")
            mask_view = gr.Image(label="Mask Image", show_label=True)
        save_button = gr.Button()
        save_button.click(Mask, inputs=sketch_input, outputs=[original_view, mask_view])

        # Step 2: prompt-driven inpainting of the saved pair.
        with gr.Row():
            prompt_box = gr.Textbox(label="Enter the prompt")
            run_button = gr.Button("Click")
            result_view = gr.Image(label="Generated Image")
        run_button.click(perform_inpainting, inputs=prompt_box, outputs=result_view)

    # Start the web server (blocks until shut down).
    demo.launch()
if __name__ == '__main__':
    # Start the app only when executed as a script, not when imported.
    main()