# sd3m/fn.py
import io
import base64

import torch
from diffusers import AutoencoderKL, DPMSolverMultistepScheduler, StableDiffusion3Pipeline

# Module-level pipeline handle, created lazily by load_model() / run().
pipe = None


def load_model(_model=None, _vae=None, loras=None):
    global pipe
    _model = _model or "v2ray/stable-diffusion-3-medium-diffusers"
    # Half precision on GPU, full precision on CPU.
    torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

    kwargs = {}
    if _vae:
        # e.g. "stabilityai/sdxl-vae"
        vae = AutoencoderKL.from_pretrained(_vae, torch_dtype=torch_dtype)
        kwargs['vae'] = vae

    pipe = StableDiffusion3Pipeline.from_pretrained(_model, torch_dtype=torch_dtype, **kwargs)

    # DPM++ 2M Karras
    # pipe.scheduler = DPMSolverMultistepScheduler.from_config(
    #     pipe.scheduler.config,
    #     algorithm_type="sde-dpmsolver++",
    #     use_karras_sigmas=True,
    # )

    # Each LoRA name is expected as "<name>.safetensors" in the current directory.
    for lora in (loras or []):
        pipe.load_lora_weights(".", weight_name=lora + ".safetensors")

    if torch.cuda.is_available():
        pipe = pipe.to("cuda")
        # pipe.enable_vae_slicing()
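

# Example (illustration only, commented out): preload the default SD3 checkpoint
# with a hypothetical LoRA file "my_style.safetensors" in the working directory.
# load_model(loras=["my_style"])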


def pil_to_webp(img):
    # Encode a PIL image to WebP bytes in memory.
    buffer = io.BytesIO()
    img.save(buffer, 'webp')
    return buffer.getvalue()


def bin_to_base64(data):
    # Base64-encode raw bytes and return an ASCII string.
    return base64.b64encode(data).decode('ascii')
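

# Illustrative helper (added as a sketch, not part of the original module):
# pil_to_webp() and bin_to_base64() are typically chained to turn a generated
# PIL image into a base64 WebP string, e.g. for returning from an API endpoint.
# The function name below is an assumption, not an existing API.
def pil_to_base64_webp(img):
    return bin_to_base64(pil_to_webp(img))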


def run(prompt=None, negative_prompt=None, model=None, guidance_scale=None, steps=None, seed=None):
    global pipe
    # Load the pipeline lazily on first use.
    if not pipe:
        load_model(model)

    _prompt = "A cat holding a sign that says hello world"
    _negative_prompt = ""
    prompt = prompt or _prompt
    negative_prompt = negative_prompt or _negative_prompt
    guidance_scale = float(guidance_scale) if guidance_scale else 7.0
    steps = int(steps) if steps else 28
    seed = int(seed) if seed else -1

    # seed == -1 means "random"; otherwise seed a generator for reproducible output.
    generator = None
    if seed != -1:
        generator = torch.manual_seed(seed)

    image = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        guidance_scale=guidance_scale,
        num_inference_steps=steps,
        clip_skip=2,
        generator=generator,
    ).images[0]
    return image
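

# Minimal usage sketch (illustration only, not part of the original file). It
# assumes the default "v2ray/stable-diffusion-3-medium-diffusers" checkpoint can
# be downloaded and that enough GPU/CPU memory is available.
if __name__ == "__main__":
    img = run(prompt="A cat holding a sign that says hello world", steps=28, seed=42)
    img.save("out.webp", "webp")
    # Or encode it as base64 WebP, as an API caller might consume it:
    print(bin_to_base64(pil_to_webp(img))[:64] + "...")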