import os

import numpy as np
import torch
from PIL import Image
from diffusers.schedulers import DDIMScheduler, UniPCMultistepScheduler

from diffusion_module.utils.Pipline import SDMLDMPipeline


def log_validation(vae, unet, noise_scheduler, accelerator, weight_dtype, data_ld,
                   resolution=512, g_step=2, save_dir="cityspace_test"):
    """Sample validation images from the current model state and save them.

    Builds an ``SDMLDMPipeline`` from the (unwrapped) ``vae``/``unet``, runs it on
    the first three batches of ``data_ld``, and writes the resulting images via
    :func:`merge_images`.

    Args:
        vae, unet: models, possibly wrapped by the accelerator.
        noise_scheduler: scheduler whose config seeds a ``UniPCMultistepScheduler``.
        accelerator: HuggingFace ``Accelerator`` (device placement + logging_dir).
        weight_dtype: torch dtype the pipeline runs in.
        data_ld: dataloader; each batch is indexable as ``batch[1]['label']``.
        resolution: pipeline output resolution.
        g_step: global step, used only to name the output directory.
        save_dir: unused here; kept for signature compatibility with callers.
    """
    scheduler = UniPCMultistepScheduler.from_config(noise_scheduler.config)
    pipeline = SDMLDMPipeline(
        vae=accelerator.unwrap_model(vae),
        unet=accelerator.unwrap_model(unet),
        scheduler=scheduler,
        torch_dtype=weight_dtype,
        resolution=resolution,
        resolution_type="crack",
    )
    pipeline = pipeline.to(accelerator.device)
    pipeline.set_progress_bar_config(disable=False)
    pipeline.enable_xformers_memory_efficient_attention()

    generator = None
    for i, batch in enumerate(data_ld):
        # Only validate on the first three batches to keep this cheap.
        if i > 2:
            break
        images = []
        with torch.autocast("cuda"):
            segmap = preprocess_input(batch[1]['label'], num_classes=151)
            # Fix: follow the pipeline's device/dtype instead of hard-coding
            # "cuda"/float16 — the pipeline was placed on accelerator.device and
            # built with torch_dtype=weight_dtype, so the condition must match.
            segmap = segmap.to(accelerator.device).to(weight_dtype)
            image = pipeline(segmap=segmap[0][None, :], generator=generator,
                             batch_size=1, num_inference_steps=50, s=1.5).images
            images.extend(image)
        merge_images(images, i, accelerator, g_step)

    # Free pipeline memory before returning to training.
    del pipeline
    torch.cuda.empty_cache()


def merge_images(images, val_step, accelerator, step):
    """Save each image individually, then save one horizontal concatenation.

    Singles go to ``<logging_dir>/step_<step>/singles/<val_step>_<k>.png``;
    the merged strip goes to ``<logging_dir>/step_<step>/merges/``.
    """
    for k, image in enumerate(images):
        filename = "{}_{}.png".format(val_step, k)
        path = os.path.join(accelerator.logging_dir, "step_{}".format(step),
                            "singles", filename)
        os.makedirs(os.path.split(path)[0], exist_ok=True)
        image.save(path)

    # New canvas wide enough to hold all images side by side.
    total_width = sum(img.width for img in images)
    max_height = max(img.height for img in images)
    combined_image = Image.new('RGB', (total_width, max_height))

    # Paste each image onto the canvas left-to-right.
    x_offset = 0
    for img in images:
        # Grayscale (or other mode) images must be RGB before pasting.
        if img.mode != 'RGB':
            img = img.convert('RGB')
        combined_image.paste(img, (x_offset, 0))
        x_offset += img.width

    merge_filename = "{}_merge.png".format(val_step)
    merge_path = os.path.join(accelerator.logging_dir, "step_{}".format(step),
                              "merges", merge_filename)
    os.makedirs(os.path.split(merge_path)[0], exist_ok=True)
    combined_image.save(merge_path)


def preprocess_input(data, num_classes):
    """One-hot encode an integer label map.

    Args:
        data: integer label tensor of shape ``(bs, 1, h, w)``.
        num_classes: number of semantic classes (channel dimension of output).

    Returns:
        Float tensor of shape ``(bs, num_classes, h, w)`` with a 1.0 at each
        pixel's class index, on the same device as ``data``.
    """
    label_map = data.to(dtype=torch.int64)
    bs, _, h, w = label_map.size()
    # Fix: allocate directly on the source device instead of the legacy
    # torch.FloatTensor(...).zero_().to(device) (CPU alloc followed by a copy).
    input_label = torch.zeros(bs, num_classes, h, w,
                              dtype=torch.float32, device=label_map.device)
    return input_label.scatter_(1, label_map, 1.0)