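# mangaka / app.py
# Hugging Face file-viewer header captured along with the source (not part of the program):
#   parsee-mizuhashi's picture | Update app.py | 7222f5b verified
#   raw | history blame contribute delete | No virus | 5.56 kB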
import torch
from PIL import Image, ImageOps, ImageSequence
import numpy as np
import comfy.sample
import comfy.sd
def vencode(vae, pth):
    """Encode image data into a latent dict with ``vae``.

    Args:
        vae: Object exposing ``encode(tensor)``.
        pth: Image data accepted by ``np.array`` (despite the name, the
            caller in this file passes the current panel image, not a path).
            Values are divided by 255, so 8-bit pixel data is presumably
            expected; the array must be 3-D for the slicing below to work.

    Returns:
        ``{"samples": <result of vae.encode>}``.
    """
    scaled = np.array(pth).astype(np.float32) / 255.0
    # Add a leading batch axis, then keep only the first three entries of the
    # last axis (drops a fourth channel such as alpha when one is present).
    batched = torch.from_numpy(scaled).unsqueeze(0)
    return {"samples": vae.encode(batched[:, :, :, :3])}
from pathlib import Path

# Local filename of the checkpoint and the URL it is fetched from on first run.
MODEL_FILE = "model.safetensors"
MODEL_URL = "https://huggingface.co/parsee-mizuhashi/mangaka/resolve/main/mangaka.safetensors?download=true"

if not Path(MODEL_FILE).exists():
    import requests

    # Stream to a temporary name and rename only after a complete, successful
    # download. Otherwise an HTTP error page or a truncated transfer would be
    # saved as model.safetensors and every later start would skip the download
    # and try to load the broken file.
    partial_path = Path(MODEL_FILE + ".part")
    with requests.get(MODEL_URL, stream=True, timeout=60) as response:
        response.raise_for_status()
        with open(partial_path, "wb") as f:
            # Chunked writes avoid holding the whole checkpoint in memory.
            for chunk in response.iter_content(chunk_size=1 << 20):
                f.write(chunk)
    partial_path.replace(MODEL_FILE)

with torch.no_grad():
    # Only the first three items of the returned sequence are used.
    unet, clip, vae = comfy.sd.load_checkpoint_guess_config(MODEL_FILE, output_vae=True, output_clip=True)[:3]

# Prefix prepended to the user's negative prompt.
# NOTE(review): the parentheses are unbalanced (four opened, one closed) -
# confirm this is intentional before changing the string.
BASE_NEG = "(low-quality worst-quality:1.4 (bad-anatomy (inaccurate-limb:1.2 bad-composition inaccurate-eyes extra-digit fewer-digits (extra-arms:1.2)"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent, denoise=1.0):
    """Run ``comfy.sample.sample`` on a latent dict and return its result.

    Args:
        model: Model object forwarded to ``comfy.sample.sample``.
        seed: Seed used both to prepare the noise and for sampling.
        steps, cfg, sampler_name, scheduler: Forwarded unchanged to the sampler.
        positive, negative: Conditioning lists forwarded unchanged.
        latent: Dict with a ``"samples"`` entry and, optionally, a
            ``"noise_mask"`` entry.
        denoise: Forwarded unchanged (defaults to 1.0).

    Returns:
        Whatever ``comfy.sample.sample`` returns; the progress bar is disabled.
    """
    latent_image = latent["samples"]
    # A missing mask is passed on as None.
    mask = latent["noise_mask"] if "noise_mask" in latent else None
    # Third argument is left as None - presumably the per-batch noise indices;
    # confirm against comfy.sample.prepare_noise.
    noise = comfy.sample.prepare_noise(latent_image, seed, None)
    return comfy.sample.sample(
        model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image,
        denoise=denoise, noise_mask=mask, disable_pbar=True, seed=seed,
    )
def set_mask(samples, mask):
    """Return a shallow copy of ``samples`` with ``mask`` stored as ``"noise_mask"``.

    The mask is reshaped to ``(-1, 1, H, W)`` where ``H`` and ``W`` are its
    last two dimensions. The input dict itself is not modified.
    """
    masked = samples.copy()
    height, width = mask.shape[-2], mask.shape[-1]
    masked.update(noise_mask=mask.reshape((-1, 1, height, width)))
    return masked
def load_image_mask(image):
    """Load an image file and return its alpha channel as a float mask.

    Args:
        image: Path (or file object) accepted by ``Image.open``.

    Returns:
        A ``torch.float32`` tensor holding the alpha channel divided by 255,
        with one extra leading dimension added.
    """
    # The context manager closes the underlying file once the pixel data has
    # been copied into the NumPy array.
    with Image.open(image) as source:
        rgba = ImageOps.exif_transpose(source)
        if rgba.getbands() != ("R", "G", "B", "A"):
            if rgba.mode == 'I':
                # Mode "I" pixels are scaled by 1/255 before the conversion.
                rgba = rgba.point(lambda value: value * (1 / 255))
            rgba = rgba.convert("RGBA")
        # The image is always RGBA at this point, so the alpha channel always
        # exists; no fallback mask is needed.
        alpha = np.array(rgba.getchannel("A")).astype(np.float32) / 255.0
    return torch.from_numpy(alpha).unsqueeze(0)
@torch.no_grad()
def main(img, variant, positive, negative, pilimg):
    """Regenerate one panel of the current page image and return the result.

    Args:
        img: Page key (a key of ``limits``), also used as the directory name
            under ``./mangaka-d/``.
        variant: Panel number as a string or int; clamped to ``limits[img]``.
        positive: User prompt, appended to a fixed greyscale prefix.
        negative: User negative prompt, appended to ``BASE_NEG``.
        pilimg: Current page image data as supplied by the ``current_panel``
            component.

    Returns:
        A ``PIL.Image`` built from the first decoded item.

    Raises:
        gr.Error: If no page image has been loaded yet, or if ``variant`` is
            not a number (for example the "none" dropdown entry).
    """
    # Imported locally so the block is self-contained; only used for
    # user-facing error messages.
    import gradio as gr

    if pilimg is None:
        raise gr.Error("No page image loaded - press Reset first.")
    try:
        panel = int(variant)
    except (TypeError, ValueError):
        raise gr.Error("Select a panel number to generate (not 'none').") from None
    panel = min(panel, limits[img])

    mask = load_image_mask(f"./mangaka-d/{img}/i{panel}.png")

    tokens = clip.tokenize("(greyscale monochrome black-and-white:1.3)" + positive)
    cond, cond_pooled = clip.encode_from_tokens(tokens, return_pooled=True)
    neg_tokens = clip.tokenize(BASE_NEG + negative)
    uncond, uncond_pooled = clip.encode_from_tokens(neg_tokens, return_pooled=True)
    cn = [[cond, {"pooled_output": cond_pooled}]]
    un = [[uncond, {"pooled_output": uncond_pooled}]]

    # Encode the current page and attach the panel mask as the noise mask.
    latent = set_mask(vencode(vae, pilimg), mask)

    # Fixed settings: seed 0, 20 steps, cfg 7, 'ddpm' sampler, 'karras'
    # scheduler, denoise 1.
    denoised = common_ksampler(unet, 0, 20, 7, 'ddpm', 'karras', cn, un, latent, denoise=1)
    decoded = vae.decode(denoised)

    # Scale the first decoded item to 0-255 and convert it to an 8-bit image.
    pixels = 255. * decoded[0].cpu().numpy()
    return Image.fromarray(np.clip(pixels, 0, 255).astype(np.uint8))
# Highest panel number available for each page key. The callbacks clamp the
# requested panel to this value before building the ``i{panel}.png`` path.
limits = {
    "1": 4, "2": 4, "3": 5,
    "4": 6, "5": 4, "6": 6,
    "7": 8, "8": 5, "9": 5,
    "s1": 4, "s2": 6, "s3": 5,
    "s4": 5, "s5": 4, "s6": 4,
}
import gradio as gr
def visualize_fn(page, panel):
    """Return the page's base image with the chosen panel highlighted in red.

    Args:
        page: Page key (a key of ``limits``) naming the directory under
            ``./mangaka-d/``.
        panel: Panel number as a string, or ``"none"`` to return the base
            image without any overlay. Numbers are clamped to ``limits[page]``.

    Returns:
        A ``PIL.Image``: the untouched base image for ``"none"``, otherwise an
        RGBA copy with the panel mask pasted on top.
    """
    base = Image.open(f"./mangaka-d/{page}/base.png")
    if panel == "none":
        return base

    panel = min(int(panel), limits[page])
    base = base.convert("RGBA")
    with Image.open(f"./mangaka-d/{page}/i{panel}.png") as mask_file:
        rgba = np.array(mask_file.convert("RGBA"))

    # Recolour every pixel whose RGB is pure white as fully opaque red; all
    # other pixels keep their original colour and alpha.
    is_white = (rgba[..., :3] == 255).all(axis=-1)
    rgba[is_white] = (255, 0, 0, 255)
    overlay = Image.fromarray(rgba)

    # Paste using the overlay's own alpha channel as the paste mask.
    base.paste(overlay, (0, 0), overlay)
    return base
def reset_fn(page):
    """Open and return the unmodified base image of ``page``."""
    return Image.open(f"./mangaka-d/{page}/base.png")
# Dropdown options are derived from `limits` so the UI cannot drift from the
# table: pages are its keys, panels run from 1 to the largest limit, plus
# "none". Pages with fewer panels are handled by the callbacks, which clamp the
# selected panel to the page's limit.
_page_choices = list(limits)
_panel_choices = [str(n) for n in range(1, max(limits.values()) + 1)] + ["none"]

# NOTE(review): the nesting below follows the conventional two-column layout
# (prompt/settings on the left, generation on the right) - confirm it matches
# the deployed app.
with gr.Blocks() as demo:
    with gr.Tab("Mangaka"):
        with gr.Row():
            # Left column: prompts and page/panel selection with a preview.
            with gr.Column():
                positive = gr.Textbox(label="Positive prompt", lines=2)
                negative = gr.Textbox(label="Negative prompt")
                with gr.Accordion("Page Settings"):
                    with gr.Row():
                        with gr.Column():
                            page = gr.Dropdown(label="Page", choices=_page_choices, value="s1")
                            panel = gr.Dropdown(label="Panel", choices=_panel_choices, value="1")
                            visualize = gr.Button("Visualize")
                        with gr.Column():
                            visualize_output = gr.Image(interactive=False)
                            visualize.click(visualize_fn, inputs=[page, panel], outputs=visualize_output)
            # Right column: generate/reset buttons and the working image.
            with gr.Column():
                with gr.Row():
                    with gr.Column():
                        generate = gr.Button("Generate", variant="primary")
                    with gr.Column():
                        reset = gr.Button("Reset", variant="stop")
                current_panel = gr.Image(interactive=False)
                # Reset loads the page's base image; Generate feeds the current
                # image back in and replaces it with the result.
                reset.click(reset_fn, inputs=[page], outputs=current_panel)
                generate.click(main, inputs=[page, panel, positive, negative, current_panel], outputs=current_panel)

demo.launch()