# Scraped from the Hugging Face Hub file viewer:
#   author: lmattingly13
#   commit: e154444 ("revert to version that runs locally")
#   size: 3.1 kB; Hub malware scan: "No virus"
import functools

import cv2
import gradio as gr
import jax
import jax.numpy as jnp
import numpy as np
from diffusers import FlaxStableDiffusionControlNetPipeline, FlaxControlNetModel
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from PIL import Image
# --- UI text ------------------------------------------------------------------
title = "ControlNet for Cartoon-ifying"
description = "This is a demo on ControlNet for changing images of people into cartoons of different styles."
# Example rows use the same column order as the Interface inputs:
# (prompt text, input image). The reference output for this example is
# ./simpsons_animated_1.jpg (kept out of the row so it matches the inputs).
examples = [["turn into a simpsons character", "./simpsons_human_1.jpg"]]

# --- Constants ----------------------------------------------------------------
# Canny edge-detection hysteresis thresholds (lower, upper).
low_threshold = 100
high_threshold = 200
base_model_path = "runwayml/stable-diffusion-v1-5"
controlnet_path = "lmattingly/controlnet-uncanny-simpsons-v2-0"
# Alternative ControlNet checkpoint tried previously: "JFoz/dog-cat-pose"

# --- Models -------------------------------------------------------------------
# Both models are fetched with `from_pretrained` at import time, in bfloat16.
controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
    controlnet_path, dtype=jnp.bfloat16
)
# The base model is loaded from the repo's "flax" revision and wired to the
# ControlNet above; `params` holds the pipeline weights, while the ControlNet
# weights stay in `controlnet_params`.
pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
    base_model_path, controlnet=controlnet, revision="flax", dtype=jnp.bfloat16
)
def canny_filter(image, low_threshold=50, high_threshold=150):
    """Return a single-channel Canny edge map of *image* as a PIL image.

    The array is converted to grayscale (unless it is already 2-D), smoothed
    with a 5x5 Gaussian blur, and passed to ``cv2.Canny``.

    Args:
        image: NumPy image array. A 3-channel input is converted with
            ``cv2.COLOR_BGR2GRAY``; a 2-D input is used as-is.
        low_threshold: Lower hysteresis threshold for ``cv2.Canny`` (default 50).
        high_threshold: Upper hysteresis threshold for ``cv2.Canny`` (default 150).

    Returns:
        A PIL image built from the 2-D edge array returned by ``cv2.Canny``.
    """
    # NOTE(review): Gradio's "image" input yields RGB arrays, for which BGR
    # weights swap the R/B luminance coefficients. Confirm the intended source.
    if image.ndim == 2:
        # Already grayscale: cvtColor(BGR2GRAY) would reject a 2-D array.
        gray_image = image
    else:
        gray_image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    blurred_image = cv2.GaussianBlur(gray_image, (5, 5), 0)
    edges_image = cv2.Canny(blurred_image, low_threshold, high_threshold)
    return Image.fromarray(edges_image)
def canny_filter2(image, low_threshold=100, high_threshold=200):
    """Return a 3-channel Canny edge map of *image* as a PIL image.

    The array goes straight to ``cv2.Canny`` (no grayscale conversion or blur).
    The resulting single-channel edge mask is copied into three identical
    channels.

    Args:
        image: NumPy image array accepted by ``cv2.Canny``.
        low_threshold: Lower hysteresis threshold (default 100, the value
            previously hard-coded here).
        high_threshold: Upper hysteresis threshold (default 200, the value
            previously hard-coded here).

    Returns:
        A PIL image built from an (H, W, 3) array of the edge mask.
    """
    edges = cv2.Canny(image, low_threshold, high_threshold)
    # cv2.Canny returns a 2-D mask; add a channel axis and repeat it 3 times.
    edges_rgb = np.repeat(edges[:, :, None], 3, axis=2)
    return Image.fromarray(edges_rgb)
def resize_image(im, max_size):
    """Scale *im* so its longer side is *max_size*, keeping the aspect ratio.

    The array's channels are swapped with ``cv2.COLOR_BGR2RGB`` before
    resizing, and the result is returned as a PIL image. Images smaller than
    *max_size* are scaled up.

    Args:
        im: NumPy image array (assumed BGR, see note below).
        max_size: Target length, in pixels, of the longer side.

    Returns:
        The resized image as a PIL image.

    Raises:
        ValueError: If *max_size* is not positive.
    """
    if max_size <= 0:
        raise ValueError(f"max_size must be positive, got {max_size!r}")
    # NOTE(review): assumes *im* is BGR (e.g. from cv2.imread). Gradio arrays
    # are already RGB, so for them this swap would be wrong. Confirm the caller.
    im_rgb = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
    height, width = im_rgb.shape[:2]
    scale_factor = max_size / max(height, width)
    # Round rather than truncate so the long side lands exactly on max_size.
    # Clamp to 1 so extreme aspect ratios never produce a zero dimension,
    # which cv2.resize rejects.
    new_width = max(1, round(width * scale_factor))
    new_height = max(1, round(height * scale_factor))
    resized_np = cv2.resize(im_rgb, (new_width, new_height))
    return Image.fromarray(resized_np)
def create_key(seed=0):
    """Build and return a JAX PRNG key from the integer *seed* (default 0)."""
    random_mod = jax.random
    key = random_mod.PRNGKey(seed)
    return key
@functools.lru_cache(maxsize=1)
def _replicated_params():
    """Return pipeline params plus ControlNet weights, replicated per local device.

    Replication copies every weight to each device, so it is done once and
    reused across requests instead of being repeated on every ``infer`` call.
    """
    # Build a new mapping instead of mutating the module-level `params`.
    merged = {**params, "controlnet": controlnet_params}
    return replicate(merged)


def infer(prompts, image, seed=0, num_inference_steps=5):
    """Cartoon-ify *image* with the Canny-conditioned ControlNet pipeline.

    Args:
        prompts: A single text prompt (string) describing the target style.
        image: Input image from the Gradio ``"image"`` input, presumably an
            RGB uint8 NumPy array. It is ``None`` when nothing was uploaded.
        seed: Seed for the JAX PRNG key (default 0, the previous fixed value).
        num_inference_steps: Denoising steps (default 5, the previous value).

    Returns:
        A list of PIL images, one per local JAX device.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        raise gr.Error("Please upload an image.")

    control_image = canny_filter2(image)

    # `shard` splits the leading axis across local devices, so the batch must
    # hold one sample per local device. A batch of 1 only works on one device.
    num_samples = jax.local_device_count()
    rng = jax.random.split(create_key(seed), num_samples)

    prompt_ids = shard(pipe.prepare_text_inputs([prompts] * num_samples))
    processed_image = shard(pipe.prepare_image_inputs([control_image] * num_samples))

    output = pipe(
        prompt_ids=prompt_ids,
        image=processed_image,
        params=_replicated_params(),
        prng_seed=rng,
        num_inference_steps=num_inference_steps,
        jit=True,
    ).images
    # Collapse the leading (device, per-device batch) axes into one sample axis.
    images = np.asarray(output).reshape((num_samples,) + output.shape[-3:])
    return pipe.numpy_to_pil(images)
# Gradio UI: a text prompt and an input image in, a gallery of results out.
demo = gr.Interface(
    fn=infer,
    inputs=["text", "image"],
    outputs="gallery",
    title=title,
    description=description,
    theme="gradio/soft",
    examples=[["a simpsons cartoon character", "simpsons_human_1.jpg"]],
)

# Launch only when run as a script, so importing this module (for tests or
# tooling that looks up `demo`) does not start a server.
if __name__ == "__main__":
    demo.launch()