import torch
import os
from PIL import Image
import numpy as np
from diffusers.schedulers import DDIMScheduler, UniPCMultistepScheduler
from diffusion_module.utils.Pipline import SDMLDMPipeline
def log_validation(vae, unet, noise_scheduler, accelerator, weight_dtype, data_ld,
                   resolution=512, g_step=2, save_dir="cityspace_test"):
    """Run the semantic-diffusion pipeline on a few validation batches and save
    the generated images (singles plus a merged strip) under the accelerator's
    logging directory."""
    scheduler = UniPCMultistepScheduler.from_config(noise_scheduler.config)
    pipeline = SDMLDMPipeline(
        vae=accelerator.unwrap_model(vae),
        unet=accelerator.unwrap_model(unet),
        scheduler=scheduler,
        torch_dtype=weight_dtype,
        resolution=resolution,
        resolution_type="crack"
    )
pipeline = pipeline.to(accelerator.device)
pipeline.set_progress_bar_config(disable=False)
pipeline.enable_xformers_memory_efficient_attention()
generator = None
    for i, batch in enumerate(data_ld):
        # Only generate for the first few validation batches.
        if i > 2:
            break
        images = []
        with torch.autocast("cuda"):
            segmap = preprocess_input(batch[1]['label'], num_classes=151)
            segmap = segmap.to("cuda").to(torch.float16)
            # The color-coded label map is not visualized here: it would need a
            # dedicated drawing helper, which is cumbersome with this many classes.
            image = pipeline(segmap=segmap[0][None, :], generator=generator, batch_size=1,
                             num_inference_steps=50, s=1.5).images
            images.extend(image)
        merge_images(images, i, accelerator, g_step)
del pipeline
torch.cuda.empty_cache()
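
# A sketch of the assumed call site inside the training loop; `val_dataloader`
# and `global_step` are hypothetical names, and batches are expected to be
# indexable as batch[1]['label'] as used above:
#
#   log_validation(vae, unet, noise_scheduler, accelerator,
#                  weight_dtype=torch.float16, data_ld=val_dataloader,
#                  resolution=512, g_step=global_step)
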
def merge_images(images, val_step, accelerator, step):
    """Save each generated image individually, then paste all of them side by
    side into a single merged strip under the accelerator's logging directory."""
    for k, image in enumerate(images):
        filename = "{}_{}.png".format(val_step, k)
        # Save each image under the 'singles' folder for this step.
        path = os.path.join(accelerator.logging_dir, "step_{}".format(step), "singles", filename)
        os.makedirs(os.path.split(path)[0], exist_ok=True)
        image.save(path)
    # Create a new canvas wide enough to hold all images side by side.
total_width = sum(img.width for img in images)
max_height = max(img.height for img in images)
combined_image = Image.new('RGB', (total_width, max_height))
    # Paste each image onto the canvas.
x_offset = 0
for img in images:
        # Convert grayscale images to RGB before pasting.
if img.mode != 'RGB':
img = img.convert('RGB')
combined_image.paste(img, (x_offset, 0))
x_offset += img.width
    # Save the merged image under the 'merges' folder.
merge_filename = "{}_merge.png".format(val_step)
merge_path = os.path.join(accelerator.logging_dir, "step_{}".format(step), "merges", merge_filename)
os.makedirs(os.path.split(merge_path)[0], exist_ok=True)
combined_image.save(merge_path)
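
# Example invocation of merge_images (a minimal sketch; SimpleNamespace is a
# stand-in for the Accelerate object, of which only `logging_dir` is used here):
#
#   from types import SimpleNamespace
#   acc = SimpleNamespace(logging_dir="validation_logs")
#   imgs = [Image.new("RGB", (64, 64)), Image.new("L", (64, 64))]
#   merge_images(imgs, val_step=0, accelerator=acc, step=0)
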
def preprocess_input(data, num_classes):
    """Convert an integer label map of shape (bs, 1, h, w) into a one-hot
    semantic map of shape (bs, num_classes, h, w)."""
    # scatter_ expects integer indices
    data = data.to(dtype=torch.int64)
    # create the one-hot label map on the same device as the input
    label_map = data
    bs, _, h, w = label_map.size()
    input_label = torch.zeros(bs, num_classes, h, w, device=data.device)
    input_semantics = input_label.scatter_(1, label_map, 1.0)
    return input_semantics
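

if __name__ == "__main__":
    # Minimal sanity check of preprocess_input (a sketch, not part of the
    # training pipeline): a 1x1x2x2 label map with class ids 0..3 should
    # become a 4-channel one-hot map that argmax-inverts back to the labels.
    labels = torch.tensor([[[[0, 1], [2, 3]]]])        # (bs=1, 1, h=2, w=2)
    one_hot = preprocess_input(labels, num_classes=4)  # (1, 4, 2, 2)
    assert one_hot.shape == (1, 4, 2, 2)
    assert torch.equal(one_hot.argmax(dim=1), labels.squeeze(1))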