roubaofeipi committed
Commit 626a5d4
Parent(s): 7d05f9e

Update train/train_t2i.py

Files changed (1)
  1. train/train_t2i.py +806 -807
train/train_t2i.py CHANGED
@@ -1,807 +1,806 @@
(The previous revision was rendered here in full; it is identical to the updated file below except for this hunk in setup_models, which drops an EMA copy of train_norm that was created but never used:)

-        train_norm.to(self.device).train().requires_grad_(True)
-        train_norm_ema = copy.deepcopy(train_norm)
-        train_norm_ema.to(self.device).eval().requires_grad_(False)
+        train_norm.to(self.device).train().requires_grad_(True)
import torch
import json
import yaml
import torchvision
from torch import nn, optim
from transformers import AutoTokenizer, CLIPTextModelWithProjection, CLIPVisionModelWithProjection
from warmup_scheduler import GradualWarmupScheduler
import torch.multiprocessing as mp
import numpy as np
import os
import sys
sys.path.append(os.path.abspath('./'))
from dataclasses import dataclass
from torch.distributed import init_process_group, destroy_process_group, barrier
from gdf import GDF_dual_fixlrt as GDF
from gdf import EpsilonTarget, CosineSchedule
from gdf import VPScaler, CosineTNoiseCond, DDPMSampler, P2LossWeight, AdaptiveLossWeight
from torchtools.transforms import SmartCrop
from fractions import Fraction
from modules.effnet import EfficientNetEncoder

from modules.model_4stage_lite import StageC, ResBlock, AttnBlock, TimestepBlock, FeedForwardBlock
from modules.previewer import Previewer
from core.data import Bucketeer
from train.base import DataCore, TrainingCore
from tqdm import tqdm
from core import WarpCore
from core.utils import EXPECTED, EXPECTED_TRAIN, load_or_fail, update_weights_ema, create_folder_if_necessary, Base

from accelerate import init_empty_weights
from accelerate.utils import set_module_tensor_to_device
from contextlib import contextmanager
from train.dist_core import *  # provides get_world_size / get_master_ip used below
import glob
from torch.utils.data import DataLoader, Dataset
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler
from PIL import Image
from modules.common_ckpt import LayerNorm2d, GlobalResponseNorm
import torch.nn.functional as F
import functools
import math
import copy
import random
from modules.lora import apply_lora, apply_retoken, LoRA, ReToken

Image.MAX_IMAGE_PIXELS = None

# fixed seeds for reproducibility
torch.manual_seed(23)
random.seed(23)
np.random.seed(23)
#7978026

+ class Null_Model(torch.nn.Module):
55
+ def __init__(self):
56
+ super().__init__()
57
+ def forward(self, x):
58
+ pass
59
+
60
+
61
+
62
+
63
+ def identity(x):
64
+ if isinstance(x, bytes):
65
+ x = x.decode('utf-8')
66
+ return x
67
+ def check_nan_inmodel(model, meta=''):
68
+ for name, param in model.named_parameters():
69
+ if torch.isnan(param).any():
70
+ print(f"nan detected in {name}", meta)
71
+ return True
72
+ print('no nan', meta)
73
+ return False
74
+ class mydist_dataset(Dataset):
75
+ def __init__(self, rootpath, img_processor=None):
76
+
77
+ self.img_pathlist = glob.glob(os.path.join(rootpath, '*', '*.jpg'))
78
+ self.img_processor = img_processor
79
+ self.length = len( self.img_pathlist)
80
+
81
+
82
+
83
+ def __getitem__(self, idx):
84
+
85
+ imgpath = self.img_pathlist[idx]
86
+ json_file = imgpath.replace('.jpg', '.json')
87
+
88
+ with open(json_file, 'r') as file:
89
+ info = json.load(file)
90
+ txt = info['caption']
91
+ if txt is None:
92
+ txt = ' '
93
+ try:
94
+ img = Image.open(imgpath).convert('RGB')
95
+ w, h = img.size
96
+ if self.img_processor is not None:
97
+ img = self.img_processor(img)
98
+
99
+ except:
100
+ print('exception', imgpath)
101
+ return self.__getitem__(random.randint(0, self.length -1 ) )
102
+ return dict(captions=txt, images=img)
103
+ def __len__(self):
104
+ return self.length
105
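# Expected on-disk layout (an assumption, inferred from the glob and the sidecar
# lookup above):
#   rootpath/<subdir>/<name>.jpg   -- image
#   rootpath/<subdir>/<name>.json  -- metadata containing at least {"caption": ...}
# e.g. mydist_dataset('/data/t2i', torchvision.transforms.ToTensor())
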
class WurstCore(TrainingCore, DataCore, WarpCore):
    @dataclass(frozen=True)
    class Config(TrainingCore.Config, DataCore.Config, WarpCore.Config):
        # TRAINING PARAMS
        lr: float = EXPECTED_TRAIN
        warmup_updates: int = EXPECTED_TRAIN
        dtype: str = None

        # MODEL VERSION
        model_version: str = EXPECTED  # 3.6B or 1B
        clip_image_model_name: str = 'openai/clip-vit-large-patch14'
        clip_text_model_name: str = 'laion/CLIP-ViT-bigG-14-laion2B-39B-b160k'

        # CHECKPOINT PATHS
        effnet_checkpoint_path: str = EXPECTED
        previewer_checkpoint_path: str = EXPECTED
        generator_checkpoint_path: str = None

        # gdf customization
        adaptive_loss_weight: str = None
        use_ddp: bool = EXPECTED

    @dataclass(frozen=True)
    class Data(Base):
        dataset: Dataset = EXPECTED
        dataloader: DataLoader = EXPECTED
        iterator: any = EXPECTED
        sampler: DistributedSampler = EXPECTED

    @dataclass(frozen=True)
    class Models(TrainingCore.Models, DataCore.Models, WarpCore.Models):
        effnet: nn.Module = EXPECTED
        previewer: nn.Module = EXPECTED
        train_norm: nn.Module = EXPECTED

    @dataclass(frozen=True)
    class Schedulers(WarpCore.Schedulers):
        generator: any = None

    @dataclass(frozen=True)
    class Extras(TrainingCore.Extras, DataCore.Extras, WarpCore.Extras):
        gdf: GDF = EXPECTED
        sampling_configs: dict = EXPECTED
        effnet_preprocess: torchvision.transforms.Compose = EXPECTED

    info: TrainingCore.Info
    config: Config

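    # (EXPECTED / EXPECTED_TRAIN above appear to act as sentinels for config keys
    # that must be supplied by the YAML config; this is an assumption based on
    # their names and their usage here.)
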
    def setup_extras_pre(self) -> Extras:
        gdf = GDF(
            schedule=CosineSchedule(clamp_range=[0.0001, 0.9999]),
            input_scaler=VPScaler(), target=EpsilonTarget(),
            noise_cond=CosineTNoiseCond(),
            loss_weight=AdaptiveLossWeight() if self.config.adaptive_loss_weight is True else P2LossWeight(),
        )
        sampling_configs = {"cfg": 5, "sampler": DDPMSampler(gdf), "shift": 1, "timesteps": 20}

        if self.info.adaptive_loss is not None:
            gdf.loss_weight.bucket_ranges = torch.tensor(self.info.adaptive_loss['bucket_ranges'])
            gdf.loss_weight.bucket_losses = torch.tensor(self.info.adaptive_loss['bucket_losses'])

        # ImageNet statistics for the EfficientNet encoder
        effnet_preprocess = torchvision.transforms.Compose([
            torchvision.transforms.Normalize(
                mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)
            )
        ])

        # CLIP input statistics
        clip_preprocess = torchvision.transforms.Compose([
            torchvision.transforms.Resize(224, interpolation=torchvision.transforms.InterpolationMode.BICUBIC),
            torchvision.transforms.CenterCrop(224),
            torchvision.transforms.Normalize(
                mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711)
            )
        ])

        if self.config.training:
            transforms = torchvision.transforms.Compose([
                torchvision.transforms.ToTensor(),
                torchvision.transforms.Resize(self.config.image_size[-1], interpolation=torchvision.transforms.InterpolationMode.BILINEAR, antialias=True),
                SmartCrop(self.config.image_size, randomize_p=0.3, randomize_q=0.2)
            ])
        else:
            transforms = None

        return self.Extras(
            gdf=gdf,
            sampling_configs=sampling_configs,
            transforms=transforms,
            effnet_preprocess=effnet_preprocess,
            clip_preprocess=clip_preprocess
        )

    def get_conditions(self, batch: dict, models: Models, extras: Extras, is_eval=False, is_unconditional=False,
                       eval_image_embeds=False, return_fields=None):
        conditions = super().get_conditions(
            batch, models, extras, is_eval, is_unconditional,
            eval_image_embeds, return_fields=return_fields or ['clip_text', 'clip_text_pooled', 'clip_img']
        )
        return conditions

    def setup_models(self, extras: Extras) -> Models:  # configure model

        dtype = getattr(torch, self.config.dtype) if self.config.dtype else torch.bfloat16

        # EfficientNet encoder
        effnet = EfficientNetEncoder()
        effnet_checkpoint = load_or_fail(self.config.effnet_checkpoint_path)
        effnet.load_state_dict(effnet_checkpoint if 'state_dict' not in effnet_checkpoint else effnet_checkpoint['state_dict'])
        effnet.eval().requires_grad_(False).to(self.device)
        del effnet_checkpoint

        # Previewer
        previewer = Previewer()
        previewer_checkpoint = load_or_fail(self.config.previewer_checkpoint_path)
        previewer.load_state_dict(previewer_checkpoint if 'state_dict' not in previewer_checkpoint else previewer_checkpoint['state_dict'])
        previewer.eval().requires_grad_(False).to(self.device)
        del previewer_checkpoint

        @contextmanager
        def dummy_context():
            yield None

        loading_context = dummy_context if self.config.training else init_empty_weights

        # Diffusion models
        with loading_context():
            generator_ema = None
            if self.config.model_version == '3.6B':
                generator = StageC()
                if self.config.ema_start_iters is not None:  # default setting
                    generator_ema = StageC()
            elif self.config.model_version == '1B':
                print('using the 1B light model', self.config.model_version)
                generator = StageC(c_cond=1536, c_hidden=[1536, 1536], nhead=[24, 24], blocks=[[4, 12], [12, 4]])
                if self.config.ema_start_iters is not None and self.config.training:
                    generator_ema = StageC(c_cond=1536, c_hidden=[1536, 1536], nhead=[24, 24], blocks=[[4, 12], [12, 4]])
            else:
                raise ValueError(f"Unknown model version {self.config.model_version}")

        if loading_context is dummy_context:
            generator.load_state_dict(load_or_fail(self.config.generator_checkpoint_path))
        else:
            # weights were created on the meta device; materialize them on CPU
            for param_name, param in load_or_fail(self.config.generator_checkpoint_path).items():
                set_module_tensor_to_device(generator, param_name, "cpu", value=param)

        generator._init_extra_parameter()
        generator = generator.to(torch.bfloat16).to(self.device)

        # collect the trainable sub-modules: one parameter-free Null_Model slot per
        # GlobalResponseNorm, plus the aggregation nets
        train_norm = nn.ModuleList()
        cnt_norm = 0
        for mm in generator.modules():
            if isinstance(mm, GlobalResponseNorm):
                train_norm.append(Null_Model())
                cnt_norm += 1

        train_norm.append(generator.agg_net)
        train_norm.append(generator.agg_net_up)
        total = sum([param.nelement() for param in train_norm.parameters()])
        print('Trainable parameters (M):', total / 1048576)

        if os.path.exists(os.path.join(self.config.output_path, self.config.experiment_id, 'train_norm.safetensors')):
            sdd = torch.load(os.path.join(self.config.output_path, self.config.experiment_id, 'train_norm.safetensors'), map_location='cpu')
            collect_sd = {}
            for k, v in sdd.items():
                collect_sd[k[7:]] = v  # strip the 'module.' prefix written by DDP
            train_norm.load_state_dict(collect_sd, strict=True)

        train_norm.to(self.device).train().requires_grad_(True)

        if generator_ema is not None:
            generator_ema.load_state_dict(load_or_fail(self.config.generator_checkpoint_path))
            generator_ema._init_extra_parameter()

            pretrained_pth = os.path.join(self.config.output_path, self.config.experiment_id, 'generator.safetensors')
            if os.path.exists(pretrained_pth):
                print(pretrained_pth, 'exists')
                generator_ema.load_state_dict(torch.load(pretrained_pth, map_location='cpu'))

            generator_ema.eval().requires_grad_(False)

        check_nan_inmodel(generator, 'generator')

        if self.config.use_ddp and self.config.training:
            train_norm = DDP(train_norm, device_ids=[self.device], find_unused_parameters=True)

        # CLIP encoders
        tokenizer = AutoTokenizer.from_pretrained(self.config.clip_text_model_name)
        text_model = CLIPTextModelWithProjection.from_pretrained(self.config.clip_text_model_name).requires_grad_(False).to(dtype).to(self.device)
        image_model = CLIPVisionModelWithProjection.from_pretrained(self.config.clip_image_model_name).requires_grad_(False).to(dtype).to(self.device)

        return self.Models(
            effnet=effnet, previewer=previewer, train_norm=train_norm,
            generator=generator, tokenizer=tokenizer, text_model=text_model, image_model=image_model,
        )

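    # Note: only train_norm is handed to the optimizer below (the Null_Model
    # placeholders hold no weights, so effectively agg_net and agg_net_up); the
    # generator body itself never receives optimizer updates even though it runs
    # in the forward pass.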
    def setup_optimizers(self, extras: Extras, models: Models) -> TrainingCore.Optimizers:
        # train_norm is DDP-wrapped in setup_models, hence the .module indirection
        params = []
        params += list(models.train_norm.module.parameters())

        optimizer = optim.AdamW(params, lr=self.config.lr)

        return self.Optimizers(generator=optimizer)

    def ema_update(self, ema_model, source_model, beta):
        for param_src, param_ema in zip(source_model.parameters(), ema_model.parameters()):
            param_ema.data.mul_(beta).add_(param_src.data, alpha=1 - beta)

    def sync_ema(self, ema_model):
        # average the EMA weights across ranks
        for param in ema_model.parameters():
            torch.distributed.all_reduce(param.data, op=torch.distributed.ReduceOp.SUM)
            param.data /= torch.distributed.get_world_size()

    def setup_optimizers_backup(self, extras: Extras, models: Models) -> TrainingCore.Optimizers:
        optimizer = optim.AdamW(
            models.generator.up_blocks.parameters(),
            lr=self.config.lr)
        optimizer = self.load_optimizer(optimizer, 'generator_optim',
                                        fsdp_model=models.generator if self.config.use_fsdp else None)
        return self.Optimizers(generator=optimizer)

    def setup_schedulers(self, extras: Extras, models: Models, optimizers: TrainingCore.Optimizers) -> Schedulers:
        scheduler = GradualWarmupScheduler(optimizers.generator, multiplier=1, total_epoch=self.config.warmup_updates)
        scheduler.last_epoch = self.info.total_steps
        return self.Schedulers(generator=scheduler)

    def setup_data(self, extras: Extras) -> WarpCore.Data:
        # SETUP DATASET
        dataset_path = self.config.webdataset_path
        dataset = mydist_dataset(dataset_path,
                                 torchvision.transforms.ToTensor() if self.config.multi_aspect_ratio is not None
                                 else extras.transforms)

        # SETUP DATALOADER
        real_batch_size = self.config.batch_size // (self.world_size * self.config.grad_accum_steps)
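        # Worked example: batch_size=512 with world_size=8 and grad_accum_steps=4
        # gives a per-GPU dataloader batch of 512 // (8 * 4) = 16; one optimizer
        # update then accumulates 16 * 4 * 8 = 512 samples across all ranks.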

        sampler = DistributedSampler(dataset, rank=self.process_id, num_replicas=self.world_size, shuffle=True)
        dataloader = DataLoader(
            dataset, batch_size=real_batch_size, num_workers=8, pin_memory=True,
            collate_fn=identity if self.config.multi_aspect_ratio is not None else None,
            sampler=sampler
        )
        if self.is_main_node:
            print(f"Training with batch size {self.config.batch_size} ({real_batch_size}/GPU)")

        if self.config.multi_aspect_ratio is not None:
            aspect_ratios = [float(Fraction(f)) for f in self.config.multi_aspect_ratio]
            dataloader_iterator = Bucketeer(dataloader, density=[ss*ss for ss in self.config.image_size], factor=32,
                                            ratios=aspect_ratios, p_random_ratio=self.config.bucketeer_random_ratio,
                                            interpolate_nearest=False)  # , use_smartcrop=True)
        else:
            dataloader_iterator = iter(dataloader)

        return self.Data(dataset=dataset, dataloader=dataloader, iterator=dataloader_iterator, sampler=sampler)

    def models_to_save(self):
        pass

    def setup_ddp(self, experiment_id, single_gpu=False, rank=0):
        if not single_gpu:
            local_rank = rank
            process_id = rank
            world_size = get_world_size()

            self.process_id = process_id
            self.is_main_node = process_id == 0
            self.device = torch.device(local_rank)
            self.world_size = world_size

            os.environ['MASTER_ADDR'] = 'localhost'
            os.environ['MASTER_PORT'] = '41443'
            torch.cuda.set_device(local_rank)
            init_process_group(
                backend="nccl",
                rank=local_rank,
                world_size=world_size,
            )
            print(f"[GPU {process_id}] READY")
        else:
            self.is_main_node = rank == 0
            self.process_id = rank
            self.device = torch.device('cuda:0')
            self.world_size = 1
            print("Running in a single process, DDP not enabled.")

    # Training loop --------------------------------
    def get_target_lr_size(self, ratio, std_size=24):
        # ratio is H/W; std_size is the base edge length in 32-pixel units
        w, h = int(std_size / math.sqrt(ratio)), int(std_size * math.sqrt(ratio))
        return (h * 32, w * 32)
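    # Worked example for get_target_lr_size: ratio = 1.0 gives (768, 768);
    # ratio = 1.5 (a tall image, H/W) gives int(24*sqrt(1.5))*32 = 928 by
    # int(24/sqrt(1.5))*32 = 608, so both sides stay multiples of 32.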
    def forward_pass(self, data: WarpCore.Data, extras: Extras, models: Models):
        # batch = next(data.iterator)
        batch = data
        ratio = batch['images'].shape[-2] / batch['images'].shape[-1]
        shape_lr = self.get_target_lr_size(ratio)
        with torch.no_grad():
            conditions = self.get_conditions(batch, models, extras)

            latents = self.encode_latents(batch, models, extras)
            latents_lr = self.encode_latents(batch, models, extras, target_size=shape_lr)

            noised, noise, target, logSNR, noise_cond, loss_weight = extras.gdf.diffuse(latents, shift=1, loss_shift=1)
            # the low-res branch is diffused at a fixed small t (0.05)
            noised_lr, noise_lr, target_lr, logSNR_lr, noise_cond_lr, loss_weight_lr = extras.gdf.diffuse(latents_lr, shift=1, loss_shift=1, t=torch.ones(latents.shape[0]).to(latents.device)*0.05)

        with torch.cuda.amp.autocast(dtype=torch.bfloat16):
            require_cond = True

            with torch.no_grad():
                # run the frozen path on the low-res latents to extract guidance
                # features ('reuire_f' is the model's own keyword, sic)
                _, lr_enc_guide, lr_dec_guide = models.generator(noised_lr, noise_cond_lr, reuire_f=True, **conditions)

            pred = models.generator(noised, noise_cond, reuire_f=False, lr_guide=(lr_enc_guide, lr_dec_guide) if require_cond else None, **conditions)
            loss = nn.functional.mse_loss(pred, target, reduction='none').mean(dim=[1, 2, 3])

            loss_adjusted = (loss * loss_weight).mean() / self.config.grad_accum_steps

        if isinstance(extras.gdf.loss_weight, AdaptiveLossWeight):
            extras.gdf.loss_weight.update_buckets(logSNR, loss)

        return loss, loss_adjusted

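    # Gradient-accumulation note: forward_pass pre-divides the loss by
    # grad_accum_steps, so the gradients summed over several backward() calls
    # below average out to one effective batch before the clipped update.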
    def backward_pass(self, update, loss_adjusted, models: Models, optimizers: TrainingCore.Optimizers, schedulers: Schedulers):
        if update:
            torch.distributed.barrier()
            loss_adjusted.backward()

            grad_norm = nn.utils.clip_grad_norm_(models.train_norm.module.parameters(), 1.0)

            optimizers_dict = optimizers.to_dict()
            for k in optimizers_dict:
                if k != 'training':
                    optimizers_dict[k].step()
            schedulers_dict = schedulers.to_dict()
            for k in schedulers_dict:
                if k != 'training':
                    schedulers_dict[k].step()
            for k in optimizers_dict:
                if k != 'training':
                    optimizers_dict[k].zero_grad(set_to_none=True)
            self.info.total_steps += 1
        else:
            # accumulation step: just backprop, no optimizer update
            loss_adjusted.backward()
            grad_norm = torch.tensor(0.0).to(self.device)

        return grad_norm

    def encode_latents(self, batch: dict, models: Models, extras: Extras, target_size=None) -> torch.Tensor:
        images = batch['images'].to(self.device)
        if target_size is not None:
            images = F.interpolate(images, target_size)
        return models.effnet(extras.effnet_preprocess(images))

    def decode_latents(self, latents: torch.Tensor, batch: dict, models: Models, extras: Extras) -> torch.Tensor:
        return models.previewer(latents)

    def __init__(self, rank=0, config_file_path=None, config_dict=None, device="cpu", training=True, world_size=1):
        self.is_main_node = (rank == 0)
        self.config: self.Config = self.setup_config(config_file_path, config_dict, training)
        self.setup_ddp(self.config.experiment_id, single_gpu=world_size <= 1, rank=rank)
        self.info: self.Info = self.setup_info()

    def __call__(self, single_gpu=False):
        if self.config.allow_tf32:
            torch.backends.cuda.matmul.allow_tf32 = True
            torch.backends.cudnn.allow_tf32 = True

        if self.is_main_node:
            print()
            print("**STARTING JOB WITH CONFIG:**")
            print(yaml.dump(self.config.to_dict(), default_flow_style=False))
            print("------------------------------------")
            print()
            print("**INFO:**")
            print(yaml.dump(vars(self.info), default_flow_style=False))
            print("------------------------------------")
            print()

        # SETUP STUFF
        extras = self.setup_extras_pre()
        assert extras is not None, "setup_extras_pre() must return a DTO"

        data = self.setup_data(extras)
        assert data is not None, "setup_data() must return a DTO"
        if self.is_main_node:
            print("**DATA:**")
            print(yaml.dump({k: type(v).__name__ for k, v in data.to_dict().items()}, default_flow_style=False))
            print("------------------------------------")
            print()

        models = self.setup_models(extras)
        assert models is not None, "setup_models() must return a DTO"
        if self.is_main_node:
            print("**MODELS:**")
            print(yaml.dump({
                k: f"{type(v).__name__} - {f'trainable params {sum(p.numel() for p in v.parameters() if p.requires_grad)}' if isinstance(v, nn.Module) else 'Not a nn.Module'}" for k, v in models.to_dict().items()
            }, default_flow_style=False))
            print("------------------------------------")
            print()

        optimizers = self.setup_optimizers(extras, models)
        assert optimizers is not None, "setup_optimizers() must return a DTO"
        if self.is_main_node:
            print("**OPTIMIZERS:**")
            print(yaml.dump({k: type(v).__name__ for k, v in optimizers.to_dict().items()}, default_flow_style=False))
            print("------------------------------------")
            print()

        schedulers = self.setup_schedulers(extras, models, optimizers)
        assert schedulers is not None, "setup_schedulers() must return a DTO"
        if self.is_main_node:
            print("**SCHEDULERS:**")
            print(yaml.dump({k: type(v).__name__ for k, v in schedulers.to_dict().items()}, default_flow_style=False))
            print("------------------------------------")
            print()

        post_extras = self.setup_extras_post(extras, models, optimizers, schedulers)
        assert post_extras is not None, "setup_extras_post() must return a DTO"
        extras = self.Extras.from_dict({**extras.to_dict(), **post_extras.to_dict()})
        if self.is_main_node:
            print("**EXTRAS:**")
            print(yaml.dump({k: f"{v}" for k, v in extras.to_dict().items()}, default_flow_style=False))
            print("------------------------------------")
            print()
        # -------

        # TRAIN
        if self.is_main_node:
            print("**TRAINING STARTING...**")
        self.train(data, extras, models, optimizers, schedulers)

        if single_gpu is False:
            barrier()
            destroy_process_group()
        if self.is_main_node:
            print()
            print("------------------------------------")
            print()
            print("**TRAINING COMPLETE**")

    def train(self, data: WarpCore.Data, extras: WarpCore.Extras, models: Models, optimizers: TrainingCore.Optimizers,
              schedulers: WarpCore.Schedulers):
        start_iter = self.info.iter + 1
        max_iters = self.config.updates * self.config.grad_accum_steps
        if self.is_main_node:
            print(f"STARTING AT STEP: {start_iter}/{max_iters}")

        if self.is_main_node:
            create_folder_if_necessary(f'{self.config.output_path}/{self.config.experiment_id}/')

        models.generator.train()

        iter_cnt = 0
        epoch_cnt = 0
        models.train_norm.train()
        while True:
            epoch_cnt += 1
            if self.world_size > 1:
                data.sampler.set_epoch(epoch_cnt)
            for ggg in range(len(data.dataloader)):
                iter_cnt += 1
                loss, loss_adjusted = self.forward_pass(next(data.iterator), extras, models)
                grad_norm = self.backward_pass(
                    iter_cnt % self.config.grad_accum_steps == 0 or iter_cnt == max_iters, loss_adjusted,
                    models, optimizers, schedulers
                )

                self.info.iter = iter_cnt

                # UPDATE LOSS METRICS
                self.info.ema_loss = loss.mean().item() if self.info.ema_loss is None else self.info.ema_loss * 0.99 + loss.mean().item() * 0.01

                # parenthesized so the NaN check only fires on the main node
                if self.is_main_node and (np.isnan(loss.mean().item()) or np.isnan(grad_norm.item())):
                    print(f"NaN value encountered in training run {self.info.wandb_run_id}",
                          f"Loss {loss.mean().item()} - Grad Norm {grad_norm.item()}. Run {self.info.wandb_run_id}")

                if self.is_main_node:
                    logs = {
                        'loss': self.info.ema_loss,
                        'backward_loss': loss_adjusted.mean().item(),
                        'ema_loss': self.info.ema_loss,
                        'raw_ori_loss': loss.mean().item(),
                        'grad_norm': grad_norm.item(),
                        'lr': optimizers.generator.param_groups[0]['lr'] if optimizers.generator is not None else 0,
                        'total_steps': self.info.total_steps,
                    }
                    if iter_cnt % self.config.save_every == 0:
                        print(iter_cnt, max_iters, logs, epoch_cnt)

                if iter_cnt == 1 or iter_cnt % self.config.save_every == 0 or iter_cnt == max_iters:
                    # SAVE AND CHECKPOINT STUFF
                    if np.isnan(loss.mean().item()):
                        if self.is_main_node and self.config.wandb_project is not None:
                            print(f"NaN value encountered in training run {self.info.wandb_run_id}",
                                  f"Loss {loss.mean().item()} - Grad Norm {grad_norm.item()}. Run {self.info.wandb_run_id}")
                    else:
                        if isinstance(extras.gdf.loss_weight, AdaptiveLossWeight):
                            self.info.adaptive_loss = {
                                'bucket_ranges': extras.gdf.loss_weight.bucket_ranges.tolist(),
                                'bucket_losses': extras.gdf.loss_weight.bucket_losses.tolist(),
                            }

                        if self.is_main_node and iter_cnt % (self.config.save_every * self.config.grad_accum_steps) == 0:
                            print('save model', iter_cnt, iter_cnt % (self.config.save_every * self.config.grad_accum_steps), self.config.save_every, self.config.grad_accum_steps)
                            # NB: written with torch.save despite the .safetensors name
                            torch.save(models.train_norm.state_dict(),
                                       f'{self.config.output_path}/{self.config.experiment_id}/train_norm.safetensors')
                            torch.save(models.train_norm.state_dict(),
                                       f'{self.config.output_path}/{self.config.experiment_id}/train_norm_{iter_cnt}.safetensors')

                if iter_cnt == 1 or iter_cnt % (self.config.save_every * self.config.grad_accum_steps) == 0 or iter_cnt == max_iters:
                    if self.is_main_node:
                        self.sample(models, data, extras)

                if self.info.iter >= max_iters:
                    return  # a bare break would only exit the inner for loop

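    # Cadence note: console logging fires every save_every iterations, while
    # checkpoints and the sample grids below are written every
    # save_every * grad_accum_steps iterations, i.e. per optimizer update rather
    # than per forward pass.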
    def sample(self, models: Models, data: WarpCore.Data, extras: Extras):
        models.generator.eval()
        models.train_norm.eval()
        with torch.no_grad():
            batch = next(data.iterator)
            ratio = batch['images'].shape[-2] / batch['images'].shape[-1]

            shape_lr = self.get_target_lr_size(ratio)
            conditions = self.get_conditions(batch, models, extras, is_eval=True, is_unconditional=False, eval_image_embeds=False)
            unconditions = self.get_conditions(batch, models, extras, is_eval=True, is_unconditional=True, eval_image_embeds=False)

            latents = self.encode_latents(batch, models, extras)
            latents_lr = self.encode_latents(batch, models, extras, target_size=shape_lr)

            if self.is_main_node:
                with torch.cuda.amp.autocast(dtype=torch.bfloat16):
                    *_, (sampled, _, _, sampled_lr) = extras.gdf.sample(
                        models.generator, conditions,
                        latents.shape, latents_lr.shape,
                        unconditions, device=self.device, **extras.sampling_configs
                    )

            if self.is_main_node:
                print('sampling results hr latent shape', latents.shape, 'lr latent shape', latents_lr.shape)
                noised_images = torch.cat(
                    [self.decode_latents(latents[i:i + 1].float(), batch, models, extras) for i in range(len(latents))], dim=0)

                sampled_images = torch.cat(
                    [self.decode_latents(sampled[i:i + 1].float(), batch, models, extras) for i in range(len(sampled))], dim=0)

                noised_images_lr = torch.cat(
                    [self.decode_latents(latents_lr[i:i + 1].float(), batch, models, extras) for i in range(len(latents_lr))], dim=0)

                sampled_images_lr = torch.cat(
                    [self.decode_latents(sampled_lr[i:i + 1].float(), batch, models, extras) for i in range(len(sampled_lr))], dim=0)

                images = batch['images']
                if images.size(-1) != noised_images.size(-1) or images.size(-2) != noised_images.size(-2):
                    images = nn.functional.interpolate(images, size=noised_images.shape[-2:], mode='bicubic')
                images_lr = nn.functional.interpolate(images, size=noised_images_lr.shape[-2:], mode='bicubic')

                # collage rows: originals / previewer reconstructions / samples
                collage_img = torch.cat([
                    torch.cat([i for i in images.cpu()], dim=-1),
                    torch.cat([i for i in noised_images.cpu()], dim=-1),
                    torch.cat([i for i in sampled_images.cpu()], dim=-1),
                ], dim=-2)

                collage_img_lr = torch.cat([
                    torch.cat([i for i in images_lr.cpu()], dim=-1),
                    torch.cat([i for i in noised_images_lr.cpu()], dim=-1),
                    torch.cat([i for i in sampled_images_lr.cpu()], dim=-1),
                ], dim=-2)

                torchvision.utils.save_image(collage_img, f'{self.config.output_path}/{self.config.experiment_id}/{self.info.total_steps:06d}.jpg')
                torchvision.utils.save_image(collage_img_lr, f'{self.config.output_path}/{self.config.experiment_id}/{self.info.total_steps:06d}_lr.jpg')

        models.generator.train()
        models.train_norm.train()
        print('finish sampling')

    def sample_fortest(self, models: Models, extras: Extras, hr_shape, lr_shape, batch, eval_image_embeds=False):
        models.generator.eval()

        with torch.no_grad():
            if self.is_main_node:
                conditions = self.get_conditions(batch, models, extras, is_eval=True, is_unconditional=False, eval_image_embeds=eval_image_embeds)
                unconditions = self.get_conditions(batch, models, extras, is_eval=True, is_unconditional=True, eval_image_embeds=False)

                with torch.cuda.amp.autocast(dtype=torch.bfloat16):
                    *_, (sampled, _, _, sampled_lr) = extras.gdf.sample(
                        models.generator, conditions,
                        hr_shape, lr_shape,
                        unconditions, device=self.device, **extras.sampling_configs
                    )

                    if models.generator_ema is not None:
                        *_, (sampled_ema, _, _, sampled_ema_lr) = extras.gdf.sample(
                            models.generator_ema, conditions,
                            hr_shape, lr_shape,  # fixed: was latents.shape / latents_lr.shape, undefined here
                            unconditions, device=self.device, **extras.sampling_configs
                        )
                    else:
                        sampled_ema = sampled
                        sampled_ema_lr = sampled_lr

        return sampled, sampled_lr

def main_worker(rank, cfg):
    print("Launching script in main worker")

    warpcore = WurstCore(
        config_file_path=cfg, rank=rank, world_size=get_world_size()
    )
    # core.fsdp_defaults['sharding_strategy'] = ShardingStrategy.NO_SHARD

    # RUN TRAINING
    warpcore(get_world_size() == 1)


if __name__ == '__main__':
    print('launch multi process')
    # os.environ["OMP_NUM_THREADS"] = "1"
    # os.environ["MKL_NUM_THREADS"] = "1"
    # dist.init_process_group(backend="nccl")
    # torch.backends.cudnn.benchmark = True
    # train/train_c_my.py
    # mp.set_sharing_strategy('file_system')

    if get_master_ip() == "127.0.0.1":
        # manually launch distributed processes
        mp.spawn(main_worker, nprocs=get_world_size(), args=(sys.argv[1] if len(sys.argv) > 1 else None,))
    else:
        main_worker(0, sys.argv[1] if len(sys.argv) > 1 else None)
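
# Usage sketch (an assumption based on the entry point above; the config path is
# hypothetical):
#   python train/train_t2i.py configs/train_t2i.yaml
# When get_master_ip() resolves to localhost, one process per visible GPU is
# spawned via mp.spawn (nprocs=get_world_size()); otherwise a single rank-0
# worker is started, presumably with rank and world size supplied by an external
# launcher through train.dist_core.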