TejasNavada committed
Commit
c430be3
1 Parent(s): c05fef2

Upload Revised.ipynb

Files changed (1)
  1. Revised.ipynb +546 -0
Revised.ipynb ADDED
@@ -0,0 +1,546 @@
+ {
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "5xhZBPJobvEm"
+ },
+ "outputs": [],
+ "source": [
+ "!pip install git+https://github.com/huggingface/diffusers.git\n",
+ "!pip install git+https://github.com/huggingface/accelerate\n",
+ "!pip install --upgrade transformers"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "KuhLUa51fQfE"
+ },
+ "outputs": [],
+ "source": [
+ "\n",
+ "!pip install datasets\n",
+ "\n",
+ "\n",
+ "!pip install torchvision\n",
+ "!sudo apt -qq install git-lfs\n",
+ "!git config --global credential.helper store\n",
+ "!pip install tqdm\n",
+ "!pip install bitsandbytes\n",
+ "!pip install torch\n",
+ "!pip install torchvision"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "t6BleLJZgKR0"
+ },
+ "outputs": [],
+ "source": [
+ "from dataclasses import dataclass\n",
+ "from datasets import load_dataset\n",
+ "from torchvision import transforms\n",
+ "from accelerate.state import AcceleratorState\n",
+ "import math\n",
+ "import os\n",
+ "import numpy as np\n",
+ "import accelerate\n",
+ "from accelerate import Accelerator\n",
+ "from tqdm.auto import tqdm\n",
+ "from pathlib import Path\n",
+ "from accelerate import notebook_launcher\n",
+ "import torch.nn.functional as F\n",
+ "from diffusers.optimization import get_cosine_schedule_with_warmup\n",
+ "import torch\n",
+ "from PIL import Image\n",
+ "from diffusers import UNet2DModel\n",
+ "from transformers import CLIPTextModel, CLIPTokenizer\n",
+ "from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel\n",
+ "from diffusers.optimization import get_scheduler\n",
+ "from huggingface_hub import create_repo, upload_folder, upload_file\n",
+ "import bitsandbytes as bnb\n",
+ "from transformers.utils import ContextManagers\n",
+ "from huggingface_hub import snapshot_download\n",
+ "\n",
+ "\n",
+ "@dataclass\n",
+ "class TrainingConfig:\n",
+ "    pretrained_model_name_or_path = \"runwayml/stable-diffusion-v1-5\"\n",
+ "    validation_prompts = [\"a dragon on a white background\", \"a fiery skull\", \"a skull\", \"a face\", \"a snake and skull\"]\n",
+ "    image_size = 512 # the generated image resolution\n",
+ "    train_batch_size = 2\n",
+ "    eval_batch_size = 2 # how many images to sample during evaluation\n",
+ "    num_epochs = 50\n",
+ "    gradient_accumulation_steps = 1\n",
+ "    lr_scheduler = \"constant\"\n",
+ "    learning_rate = 1e-5\n",
+ "    lr_warmup_steps = 500\n",
+ "    save_image_epochs = 1\n",
+ "    save_model_epochs = 1\n",
+ "    token = \"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\"\n",
+ "    num_processes = 1\n",
+ "    mixed_precision = \"fp16\" # `no` for float32, `fp16` for automatic mixed precision\n",
+ "    output_dir = \"tattoo-diffusion\" # the model name locally and on the HF Hub\n",
+ "\n",
+ "    push_to_hub = True # whether to upload the saved model to the HF Hub\n",
+ "    hub_private_repo = False\n",
+ "    overwrite_output_dir = True # overwrite the old model when re-running the notebook\n",
+ "    seed = 0\n",
+ "\n",
+ "\n",
+ "config = TrainingConfig()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "yBKWnM2p_qI6"
+ },
+ "outputs": [],
+ "source": [
+ "snapshot_download(repo_id=\"TejasNavada/tattoo-diffusion\", local_dir=config.output_dir, local_dir_use_symlinks=False)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "GI92xkd-jy7C"
+ },
+ "outputs": [],
+ "source": [
+ "\n",
+ "\n",
+ "def make_grid(images, rows, cols):\n",
+ "    w, h = images[0].size\n",
+ "    grid = Image.new(\"RGB\", size=(cols * w, rows * h))\n",
+ "    for i, image in enumerate(images):\n",
+ "        grid.paste(image, box=(i % cols * w, i // cols * h))\n",
+ "    return grid\n",
+ "\n",
+ "\n",
+ "def evaluate(vae, text_encoder, tokenizer, unet, config, accelerator, epoch):\n",
+ "    pipeline = StableDiffusionPipeline.from_pretrained(\n",
+ "        config.pretrained_model_name_or_path,\n",
+ "        vae=accelerator.unwrap_model(vae),\n",
+ "        text_encoder=accelerator.unwrap_model(text_encoder),\n",
+ "        tokenizer=tokenizer,\n",
+ "        unet=accelerator.unwrap_model(unet),\n",
+ "        safety_checker=None,\n",
+ "        torch_dtype=torch.float16,\n",
+ "    )\n",
+ "\n",
+ "    pipeline = pipeline.to(accelerator.device)\n",
+ "    pipeline.set_progress_bar_config(disable=True)\n",
+ "\n",
+ "    generator = torch.Generator(device=accelerator.device).manual_seed(config.seed)\n",
+ "\n",
+ "    images = []\n",
+ "\n",
+ "    for i in range(len(config.validation_prompts)):\n",
+ "        with torch.autocast(\"cuda\"):\n",
+ "            image = pipeline(config.validation_prompts[i], num_inference_steps=20, generator=generator).images[0]\n",
+ "\n",
+ "        images.append(image)\n",
+ "\n",
+ "    for tracker in accelerator.trackers:\n",
+ "        if tracker.name == \"tensorboard\":\n",
+ "            np_images = np.stack([np.asarray(img) for img in images])\n",
+ "            tracker.writer.add_images(\"validation\", np_images, epoch, dataformats=\"NHWC\")\n",
+ "\n",
+ "    del pipeline\n",
+ "    torch.cuda.empty_cache()\n",
+ "\n",
+ "    image_grid = make_grid(images, rows=1, cols=len(images))\n",
+ "\n",
+ "    test_dir = os.path.join(config.output_dir, \"samples\")\n",
+ "    os.makedirs(test_dir, exist_ok=True)\n",
+ "    image_grid.save(f\"{test_dir}/{epoch:04d}.png\")\n",
+ "\n",
+ "    return images\n",
+ "\n",
+ "\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "kh-C1RIAgRMV"
+ },
+ "outputs": [],
+ "source": [
+ "\n",
+ "\n",
+ "config.dataset_name = \"Drozdik/tattoo_v3\"\n",
+ "dataset = load_dataset(config.dataset_name, split=\"train\")\n",
+ "tokenizer = CLIPTokenizer.from_pretrained(\n",
+ "    config.pretrained_model_name_or_path, subfolder=\"tokenizer\",\n",
+ ")\n",
+ "preprocess = transforms.Compose(\n",
+ "    [\n",
+ "        transforms.Resize((config.image_size, config.image_size)),\n",
+ "        transforms.RandomHorizontalFlip(),\n",
+ "        transforms.ToTensor(),\n",
+ "        transforms.Normalize([0.5], [0.5]),\n",
+ "    ]\n",
+ ")\n",
+ "\n",
+ "def tokenize_captions(examples):\n",
+ "    captions = examples[\"text\"]\n",
+ "    inputs = tokenizer(\n",
+ "        captions, max_length=tokenizer.model_max_length, padding=\"max_length\", truncation=True, return_tensors=\"pt\"\n",
+ "    )\n",
+ "    return inputs.input_ids\n",
+ "\n",
+ "\n",
+ "\n",
+ "def transform(examples):\n",
+ "    images = [preprocess(image.convert(\"RGB\")) for image in examples[\"image\"]]\n",
+ "    examples[\"pixel_values\"] = images\n",
+ "    examples[\"input_ids\"] = tokenize_captions(examples)\n",
+ "    return examples"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "MVNCvm8nIiQd"
+ },
+ "outputs": [],
+ "source": [
+ "def collate_fn(examples):\n",
+ "    pixel_values = torch.stack([example[\"pixel_values\"] for example in examples])\n",
+ "    pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()\n",
+ "    input_ids = torch.stack([example[\"input_ids\"] for example in examples])\n",
+ "    return {\"pixel_values\": pixel_values, \"input_ids\": input_ids}\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "43Z-VBpQi5Yt"
+ },
+ "outputs": [],
+ "source": [
+ "def save_model_card(args, repo_id: str, images=None, repo_folder=None):\n",
+ "    img_str = \"\"\n",
+ "    if images is not None and len(images) > 0:\n",
+ "        image_grid = make_grid(images, 1, len(config.validation_prompts))\n",
+ "        image_grid.save(os.path.join(repo_folder, \"val_imgs_grid.png\"))\n",
+ "        img_str += \"![val_imgs_grid](./val_imgs_grid.png)\\n\"\n",
+ "    yaml = f\"\"\"\n",
+ "---\n",
+ "license: creativeml-openrail-m\n",
+ "base_model: {config.pretrained_model_name_or_path}\n",
+ "datasets:\n",
+ "- {config.dataset_name}\n",
+ "tags:\n",
+ "- stable-diffusion\n",
+ "- stable-diffusion-diffusers\n",
+ "- text-to-image\n",
+ "- diffusers\n",
+ "inference: true\n",
+ "---\n",
+ "    \"\"\"\n",
+ "    model_card = f\"\"\"\n",
+ "# Text-to-image finetuning - {repo_id}\n",
+ "\n",
+ "This pipeline was finetuned from **{config.pretrained_model_name_or_path}** on the **{config.dataset_name}** dataset. Below are some example images generated with the finetuned pipeline using the following prompts: {config.validation_prompts}: \\n\n",
+ "{img_str}\n",
+ "\n",
+ "## Pipeline usage\n",
+ "\n",
+ "You can use the pipeline like so:\n",
+ "\n",
+ "```python\n",
+ "from diffusers import DiffusionPipeline\n",
+ "import torch\n",
+ "\n",
+ "pipeline = DiffusionPipeline.from_pretrained(\"{repo_id}\", torch_dtype=torch.float16)\n",
+ "prompt = \"{config.validation_prompts[0]}\"\n",
+ "image = pipeline(prompt).images[0]\n",
+ "image.save(\"my_image.png\")\n",
+ "```\n",
+ "\n",
+ "## Training info\n",
+ "\n",
+ "These are the key hyperparameters used during training:\n",
+ "\n",
+ "* Epochs: {config.num_epochs}\n",
+ "* Learning rate: {config.learning_rate}\n",
+ "* Batch size: {config.train_batch_size}\n",
+ "* Image resolution: {config.image_size}\n",
+ "* Mixed-precision: {config.mixed_precision}\n",
+ "\n",
+ "\"\"\"\n",
+ "    with open(os.path.join(repo_folder, \"README.md\"), \"w\") as f:\n",
+ "        f.write(yaml + model_card)\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "VbgnI0pJtsFQ"
+ },
+ "outputs": [],
+ "source": [
+ "def deepspeed_zero_init_disabled_context_manager():\n",
+ "    \"\"\"\n",
+ "    returns either a context list that includes one that will disable zero.Init or an empty context list\n",
+ "    \"\"\"\n",
+ "    deepspeed_plugin = AcceleratorState().deepspeed_plugin if accelerate.state.is_initialized() else None\n",
+ "    if deepspeed_plugin is None:\n",
+ "        return []\n",
+ "\n",
+ "    return [deepspeed_plugin.zero3_init_context_manager(enable=False)]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "c6162g9pLz5r"
+ },
+ "outputs": [],
+ "source": [
+ "def train_loop(config, unet, vae, noise_scheduler, optimizer, train_dataloader, lr_scheduler):\n",
+ "    repo_id = \"TejasNavada/tattoo-diffusion\"\n",
+ "\n",
+ "    accelerator = Accelerator(\n",
+ "        mixed_precision=config.mixed_precision,\n",
+ "        gradient_accumulation_steps=config.gradient_accumulation_steps,\n",
+ "        log_with=\"tensorboard\",\n",
+ "        project_dir=os.path.join(config.output_dir, \"logs\"),\n",
+ "    )\n",
+ "    state_dict = lr_scheduler.state_dict()\n",
+ "    print(state_dict)\n",
+ "    if accelerator.is_main_process:\n",
+ "        os.makedirs(config.output_dir, exist_ok=True)\n",
+ "        accelerator.init_trackers(\"train_example\")\n",
+ "\n",
+ "    unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(\n",
+ "        unet, optimizer, train_dataloader, lr_scheduler\n",
+ "    )\n",
+ "\n",
+ "\n",
+ "    text_encoder.to(accelerator.device, dtype=torch.float16)\n",
+ "    vae.to(accelerator.device, dtype=torch.float16)\n",
+ "    global_step = 0\n",
+ "\n",
+ "    if True:\n",
+ "\n",
+ "        dirs = os.listdir(config.output_dir)\n",
+ "        dirs = [d for d in dirs if d.startswith(\"checkpoint\")]\n",
+ "        dirs = sorted(dirs, key=lambda x: int(x.split(\"-\")[1]))\n",
+ "        path = dirs[-1] if len(dirs) > 0 else None\n",
+ "        accelerator.print(f\"Resuming from checkpoint {path}\")\n",
+ "        accelerator.load_state(os.path.join(config.output_dir, path))\n",
+ "        global_step = int(path.split(\"-\")[1])\n",
+ "\n",
+ "    start_epoch = global_step // len(train_dataloader)\n",
+ "\n",
+ "    lr_scheduler.load_state_dict(state_dict)\n",
+ "    print(lr_scheduler.get_last_lr())\n",
+ "\n",
+ "    for epoch in range(start_epoch, config.num_epochs):\n",
+ "        unet.train()\n",
+ "\n",
+ "        progress_bar = tqdm(total=len(train_dataloader), disable=not accelerator.is_local_main_process)\n",
+ "        progress_bar.set_description(f\"Epoch {epoch}\")\n",
+ "\n",
+ "        for step, batch in enumerate(train_dataloader):\n",
+ "\n",
+ "            # Convert images to latent space\n",
+ "            latents = vae.encode(batch[\"pixel_values\"].to(torch.float16)).latent_dist.sample()\n",
+ "            latents = latents * vae.config.scaling_factor\n",
+ "\n",
+ "            # Sample noise to add to the latents\n",
+ "            noise = torch.randn_like(latents)\n",
+ "\n",
+ "            bsz = latents.shape[0]\n",
+ "\n",
+ "            # Sample a random timestep for each image\n",
+ "            timesteps = torch.randint(\n",
+ "                0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device\n",
+ "            ).long()\n",
+ "            # Add noise to the latents according to the noise magnitude at each timestep\n",
+ "            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)\n",
+ "            # Get the text embedding for conditioning\n",
+ "            encoder_hidden_states = text_encoder(batch[\"input_ids\"])[0]\n",
+ "            # Predict the noise residual and compute loss\n",
+ "            with accelerator.accumulate(unet):\n",
+ "\n",
+ "                model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample\n",
+ "\n",
+ "                loss = F.mse_loss(model_pred.float(), noise.float(), reduction=\"mean\")\n",
+ "\n",
+ "                # Backpropagate\n",
+ "                accelerator.backward(loss)\n",
+ "                accelerator.clip_grad_norm_(unet.parameters(), 1.0)\n",
+ "\n",
+ "                optimizer.step()\n",
+ "                lr_scheduler.step()\n",
+ "                optimizer.zero_grad()\n",
+ "\n",
+ "            progress_bar.update(1)\n",
+ "            logs = {\"loss\": loss.detach().item(), \"lr\": lr_scheduler.get_last_lr()[0], \"step\": global_step}\n",
+ "            progress_bar.set_postfix(**logs)\n",
+ "            accelerator.log(logs, step=global_step)\n",
+ "            global_step += 1\n",
+ "\n",
+ "        if accelerator.is_main_process:\n",
+ "\n",
+ "            if (epoch + 1) % config.save_image_epochs == 0 or epoch == config.num_epochs - 1:\n",
+ "                images = evaluate(vae, text_encoder, tokenizer, unet, config, accelerator, epoch)\n",
+ "                save_path = os.path.join(config.output_dir, f\"checkpoint-{global_step}\")\n",
+ "                accelerator.save_state(save_path)\n",
+ "                save_model_card(config, repo_id, images, repo_folder=config.output_dir)\n",
+ "                upload_folder(\n",
+ "                    repo_id=repo_id,\n",
+ "                    folder_path=save_path,\n",
+ "                    path_in_repo=f\"checkpoint-{global_step}\",\n",
+ "                    commit_message=\"Latest Checkpoint\",\n",
+ "                    ignore_patterns=[\"step_*\", \"epoch_*\"],\n",
+ "                )\n",
+ "                upload_folder(\n",
+ "                    repo_id=repo_id,\n",
+ "                    folder_path=os.path.join(config.output_dir, \"samples\"),\n",
+ "                    path_in_repo=\"samples\",\n",
+ "                    commit_message=\"new samples\",\n",
+ "                    ignore_patterns=[\"step_*\", \"epoch_*\"],\n",
+ "                )\n",
+ "                upload_file(\n",
+ "                    path_or_fileobj=os.path.join(config.output_dir, \"README.md\"),\n",
+ "                    path_in_repo=\"README.md\",\n",
+ "                    repo_id=repo_id,\n",
+ "                )\n",
+ "\n",
+ "    unet = accelerator.unwrap_model(unet)\n",
+ "    pipeline = StableDiffusionPipeline.from_pretrained(\n",
+ "        config.pretrained_model_name_or_path,\n",
+ "        text_encoder=text_encoder,\n",
+ "        vae=vae,\n",
+ "        unet=unet,\n",
+ "    )\n",
+ "    pipeline.save_pretrained(config.output_dir)\n",
+ "    accelerator.end_training()\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "L21-Cx7NrghU"
+ },
+ "outputs": [],
+ "source": [
+ "config.validation_prompts[0]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "ofrTlboPpwX9"
+ },
+ "outputs": [],
+ "source": [
+ "import huggingface_hub\n",
+ "huggingface_hub.login(config.token, add_to_git_credential=True, new_session=True, write_permission=True)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "3o2O7BkjmNsB"
+ },
+ "outputs": [],
+ "source": [
+ "dataset.set_transform(transform)\n",
+ "train_dataloader = torch.utils.data.DataLoader(dataset, collate_fn=collate_fn, batch_size=config.train_batch_size, shuffle=True)\n",
+ "noise_scheduler = DDPMScheduler.from_pretrained(config.pretrained_model_name_or_path, subfolder=\"scheduler\")\n",
+ "with ContextManagers(deepspeed_zero_init_disabled_context_manager()):\n",
+ "    text_encoder = CLIPTextModel.from_pretrained(\n",
+ "        config.pretrained_model_name_or_path, subfolder=\"text_encoder\",\n",
+ "    )\n",
+ "    vae = AutoencoderKL.from_pretrained(\n",
+ "        config.pretrained_model_name_or_path, subfolder=\"vae\",\n",
+ "    )\n",
+ "\n",
+ "\n",
+ "\n",
+ "unet = UNet2DConditionModel(\n",
+ "    sample_size=config.image_size // 8,\n",
+ "    cross_attention_dim=768,\n",
+ ")\n",
+ "\n",
+ "vae.requires_grad_(False)\n",
+ "text_encoder.requires_grad_(False)\n",
+ "optimizer = bnb.optim.AdamW8bit(\n",
+ "    unet.parameters(),\n",
+ "    lr=config.learning_rate,\n",
+ ")\n",
+ "lr_scheduler = get_scheduler(\n",
+ "    config.lr_scheduler,\n",
+ "    optimizer=optimizer,\n",
+ "    num_warmup_steps=config.lr_warmup_steps,\n",
+ "    num_training_steps=(len(train_dataloader) * config.num_epochs),\n",
+ ")\n",
+ "\n",
+ "\n",
+ "args = (config, unet, vae, noise_scheduler, optimizer, train_dataloader, lr_scheduler)\n",
+ "\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "notebook_launcher(train_loop, args, num_processes=1)"
+ ],
+ "metadata": {
+ "id": "GCR1zr9EKLyw"
+ },
+ "execution_count": null,
+ "outputs": []
+ }
+ ],
+ "metadata": {
+ "accelerator": "GPU",
+ "colab": {
+ "provenance": [],
+ "gpuType": "T4"
+ },
+ "kernelspec": {
+ "display_name": "Python 3",
+ "name": "python3"
+ },
+ "language_info": {
+ "name": "python"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+ }