{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "id": "5xhZBPJobvEm" }, "outputs": [], "source": [ "!pip install git+https://github.com/huggingface/diffusers.git\n", "!pip install git+https://github.com/huggingface/accelerate\n", "!pip install --upgrade transformers" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "KuhLUa51fQfE" }, "outputs": [], "source": [ "\n", "!pip install datasets\n", "\n", "\n", "!pip install torchvision\n", "!sudo apt -qq install git-lfs\n", "!git config --global credential.helper store\n", "!pip install tqdm\n", "!pip install bitsandbytes\n", "!pip install torch\n", "!pip install torchvision" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "t6BleLJZgKR0" }, "outputs": [], "source": [ "from dataclasses import dataclass\n", "from datasets import load_dataset\n", "from torchvision import transforms\n", "from accelerate.state import AcceleratorState\n", "import math\n", "import os\n", "import numpy as np\n", "import accelerate\n", "from accelerate import Accelerator\n", "from tqdm.auto import tqdm\n", "from pathlib import Path\n", "from accelerate import notebook_launcher\n", "import torch.nn.functional as F\n", "from diffusers.optimization import get_cosine_schedule_with_warmup\n", "import torch\n", "from PIL import Image\n", "from diffusers import UNet2DModel\n", "from transformers import CLIPTextModel, CLIPTokenizer\n", "from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel\n", "from diffusers.optimization import get_scheduler\n", "from huggingface_hub import create_repo, upload_folder, upload_file\n", "import bitsandbytes as bnb\n", "from transformers.utils import ContextManagers\n", "from huggingface_hub import snapshot_download\n", "\n", "\n", "@dataclass\n", "class TrainingConfig:\n", " pretrained_model_name_or_path = \"runwayml/stable-diffusion-v1-5\"\n", " validation_prompts = [\"a dragon on a white background\",\" a fiery skull\", \"a skull\", \"a face\", \"a snake and skull\"]\n", " image_size = 512 # the generated image resolution\n", " train_batch_size = 2\n", " eval_batch_size = 2 # how many images to sample during evaluation\n", " num_epochs = 50\n", " gradient_accumulation_steps = 1\n", " lr_scheduler = \"constant\"\n", " learning_rate = 1e-5\n", " lr_warmup_steps = 500\n", " save_image_epochs = 1\n", " save_model_epochs = 1\n", " token = \"XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX\"\n", " num_processes = 1\n", " mixed_precision = \"fp16\" # `no` for float32, `fp16` for automatic mixed precision\n", " output_dir = \"tattoo-diffusion\" # the model name locally and on the HF Hub\n", "\n", " push_to_hub = True # whether to upload the saved model to the HF Hub\n", " hub_private_repo = False\n", " overwrite_output_dir = True # overwrite the old model when re-running the notebook\n", " seed = 0\n", "\n", "\n", "config = TrainingConfig()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "yBKWnM2p_qI6" }, "outputs": [], "source": [ "snapshot_download(repo_id=\"TejasNavada/tattoo-diffusion\", local_dir=config.output_dir, local_dir_use_symlinks=False )" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "GI92xkd-jy7C" }, "outputs": [], "source": [ "\n", "\n", "def make_grid(images, rows, cols):\n", " w, h = images[0].size\n", " grid = Image.new(\"RGB\", size=(cols * w, rows * h))\n", " for i, image in enumerate(images):\n", " grid.paste(image, box=(i % cols * w, i // cols * h))\n", " return grid\n", 
"\n", "\n", "def evaluate(vae, text_encoder, tokenizer, unet, config, accelerator, epoch):\n", " pipeline = StableDiffusionPipeline.from_pretrained(\n", " config.pretrained_model_name_or_path,\n", " vae=accelerator.unwrap_model(vae),\n", " text_encoder=accelerator.unwrap_model(text_encoder),\n", " tokenizer=tokenizer,\n", " unet=accelerator.unwrap_model(unet),\n", " safety_checker=None,\n", " torch_dtype=torch.float16,\n", " )\n", "\n", " pipeline = pipeline.to(accelerator.device)\n", " pipeline.set_progress_bar_config(disable=True)\n", "\n", " generator = torch.Generator(device=accelerator.device).manual_seed(config.seed)\n", "\n", " images = []\n", "\n", " for i in range(len(config.validation_prompts)):\n", " with torch.autocast(\"cuda\"):\n", " image = pipeline(config.validation_prompts[i], num_inference_steps=20, generator=None).images[0]\n", "\n", " images.append(image)\n", "\n", " for tracker in accelerator.trackers:\n", " if tracker.name == \"tensorboard\":\n", " np_images = np.stack([np.asarray(img) for img in images])\n", " tracker.writer.add_images(\"validation\", np_images, epoch, dataformats=\"NHWC\")\n", "\n", " del pipeline\n", " torch.cuda.empty_cache()\n", "\n", " image_grid = make_grid(images, rows=1, cols=len(images))\n", "\n", " test_dir = os.path.join(config.output_dir, \"samples\")\n", " os.makedirs(test_dir, exist_ok=True)\n", " image_grid.save(f\"{test_dir}/{epoch:04d}.png\")\n", "\n", " return images\n", "\n", "\n", "\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "kh-C1RIAgRMV" }, "outputs": [], "source": [ "\n", "\n", "config.dataset_name = \"Drozdik/tattoo_v3\"\n", "dataset = load_dataset(config.dataset_name, split=\"train\")\n", "tokenizer = CLIPTokenizer.from_pretrained(\n", " config.pretrained_model_name_or_path, subfolder=\"tokenizer\",\n", " )\n", "preprocess = transforms.Compose(\n", " [\n", " transforms.Resize((config.image_size, config.image_size)),\n", " transforms.RandomHorizontalFlip(),\n", " transforms.ToTensor(),\n", " transforms.Normalize([.5],[.5]),\n", " ]\n", ")\n", "\n", "def tokenize_captions(examples):\n", " captions = examples[\"text\"]\n", " inputs = tokenizer(\n", " captions, max_length=tokenizer.model_max_length, padding=\"max_length\", truncation=True, return_tensors=\"pt\"\n", " )\n", " return inputs.input_ids\n", "\n", "\n", "\n", "def transform(examples):\n", " images = [preprocess(image.convert(\"RGB\")) for image in examples[\"image\"]]\n", " examples[\"pixel_values\"] = images\n", " examples[\"input_ids\"] = tokenize_captions(examples)\n", " return examples" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "MVNCvm8nIiQd" }, "outputs": [], "source": [ "def collate_fn(examples):\n", " pixel_values = torch.stack([example[\"pixel_values\"] for example in examples])\n", " pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()\n", " input_ids = torch.stack([example[\"input_ids\"] for example in examples])\n", " return {\"pixel_values\": pixel_values, \"input_ids\": input_ids}\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "43Z-VBpQi5Yt" }, "outputs": [], "source": [ "def save_model_card(args,repo_id: str,images=None,repo_folder=None):\n", " img_str = \"\"\n", " if images is not None and len(images) > 0:\n", " image_grid = make_grid(images, 1, len(config.validation_prompts))\n", " image_grid.save(os.path.join(repo_folder, \"val_imgs_grid.png\"))\n", " img_str += \"![val_imgs_grid](./val_imgs_grid.png)\\n\"\n", " yaml = f\"\"\"\n", 
"---\n", "license: creativeml-openrail-m\n", "base_model: {config.pretrained_model_name_or_path}\n", "datasets:\n", "- {config.dataset_name}\n", "tags:\n", "- stable-diffusion\n", "- stable-diffusion-diffusers\n", "- text-to-image\n", "- diffusers\n", "inference: true\n", "---\n", " \"\"\"\n", " model_card = f\"\"\"\n", "# Text-to-image finetuning - {repo_id}\n", "\n", "This pipeline was finetuned from **{config.pretrained_model_name_or_path}** on the **{config.dataset_name}** dataset. Below are some example images generated with the finetuned pipeline using the following prompts: {config.validation_prompts}: \\n\n", "{img_str}\n", "\n", "## Pipeline usage\n", "\n", "You can use the pipeline like so:\n", "\n", "```python\n", "from diffusers import DiffusionPipeline\n", "import torch\n", "\n", "pipeline = DiffusionPipeline.from_pretrained(\"{repo_id}\", torch_dtype=torch.float16)\n", "prompt = \"{config.validation_prompts[0]}\"\n", "image = pipeline(prompt).images[0]\n", "image.save(\"my_image.png\")\n", "```\n", "\n", "## Training info\n", "\n", "These are the key hyperparameters used during training:\n", "\n", "* Epochs: {config.num_epochs}\n", "* Learning rate: {config.learning_rate}\n", "* Batch size: {config.train_batch_size}\n", "* Image resolution: {config.image_size}\n", "* Mixed-precision: {config.mixed_precision}\n", "\n", "\"\"\"\n", " with open(os.path.join(repo_folder, \"README.md\"), \"w\") as f:\n", " f.write(yaml + model_card)\n", "\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "VbgnI0pJtsFQ" }, "outputs": [], "source": [ "def deepspeed_zero_init_disabled_context_manager():\n", " \"\"\"\n", " returns either a context list that includes one that will disable zero.Init or an empty context list\n", " \"\"\"\n", " deepspeed_plugin = AcceleratorState().deepspeed_plugin if accelerate.state.is_initialized() else None\n", " if deepspeed_plugin is None:\n", " return []\n", "\n", " return [deepspeed_plugin.zero3_init_context_manager(enable=False)]" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "c6162g9pLz5r" }, "outputs": [], "source": [ "def train_loop(config, unet, vae, noise_scheduler, optimizer, train_dataloader, lr_scheduler):\n", " repo_id = \"TejasNavada/tattoo-diffusion\"\n", "\n", " accelerator = Accelerator(\n", " mixed_precision=config.mixed_precision,\n", " gradient_accumulation_steps=config.gradient_accumulation_steps,\n", " log_with=\"tensorboard\",\n", " project_dir=os.path.join(config.output_dir, \"logs\"),\n", " )\n", " state_dict = lr_scheduler.state_dict()\n", " print(state_dict)\n", " if accelerator.is_main_process:\n", " os.makedirs(config.output_dir,exist_ok=True)\n", " accelerator.init_trackers(\"train_example\")\n", "\n", " unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(\n", " unet, optimizer, train_dataloader, lr_scheduler\n", " )\n", "\n", "\n", " text_encoder.to(accelerator.device, dtype=torch.float16)\n", " vae.to(accelerator.device, dtype=torch.float16)\n", " global_step = 0\n", "\n", " if(True):\n", "\n", " dirs = os.listdir(config.output_dir)\n", " dirs = [d for d in dirs if d.startswith(\"checkpoint\")]\n", " dirs = sorted(dirs, key=lambda x: int(x.split(\"-\")[1]))\n", " path = dirs[-1] if len(dirs) > 0 else None\n", " accelerator.print(f\"Resuming from checkpoint {path}\")\n", " accelerator.load_state(os.path.join(config.output_dir, path))\n", " global_step = int(path.split(\"-\")[1])\n", "\n", " start_epoch = global_step//len(train_dataloader)\n", "\n", " 
 "    lr_scheduler.load_state_dict(state_dict)\n",
 "    print(lr_scheduler.get_last_lr())\n",
 "\n",
 "    for epoch in range(start_epoch, config.num_epochs):\n",
 "        unet.train()\n",
 "\n",
 "        progress_bar = tqdm(total=len(train_dataloader), disable=not accelerator.is_local_main_process)\n",
 "        progress_bar.set_description(f\"Epoch {epoch}\")\n",
 "\n",
 "        for step, batch in enumerate(train_dataloader):\n",
 "            # Convert images to latent space\n",
 "            latents = vae.encode(batch[\"pixel_values\"].to(torch.float16)).latent_dist.sample()\n",
 "            latents = latents * vae.config.scaling_factor\n",
 "\n",
 "            # Sample noise to add to the latents\n",
 "            noise = torch.randn_like(latents)\n",
 "\n",
 "            bsz = latents.shape[0]\n",
 "\n",
 "            # Sample a random timestep for each image\n",
 "            timesteps = torch.randint(\n",
 "                0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device\n",
 "            ).long()\n",
 "            # Add noise to the latents according to the noise magnitude at each timestep\n",
 "            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)\n",
 "            # Get the text embedding for conditioning\n",
 "            encoder_hidden_states = text_encoder(batch[\"input_ids\"])[0]\n",
 "            # Predict the noise residual and compute the loss\n",
 "            with accelerator.accumulate(unet):\n",
 "                model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample\n",
 "                loss = F.mse_loss(model_pred.float(), noise.float(), reduction=\"mean\")\n",
 "\n",
 "                # Backpropagate\n",
 "                accelerator.backward(loss)\n",
 "                accelerator.clip_grad_norm_(unet.parameters(), 1.0)\n",
 "\n",
 "                optimizer.step()\n",
 "                lr_scheduler.step()\n",
 "                optimizer.zero_grad()\n",
 "\n",
 "            progress_bar.update(1)\n",
 "            logs = {\"loss\": loss.detach().item(), \"lr\": lr_scheduler.get_last_lr()[0], \"step\": global_step}\n",
 "            progress_bar.set_postfix(**logs)\n",
 "            accelerator.log(logs, step=global_step)\n",
 "            global_step += 1\n",
 "\n",
 "        if accelerator.is_main_process:\n",
 "            if (epoch + 1) % config.save_image_epochs == 0 or epoch == config.num_epochs - 1:\n",
 "                images = evaluate(vae, text_encoder, tokenizer, unet, config, accelerator, epoch)\n",
 "                save_path = os.path.join(config.output_dir, f\"checkpoint-{global_step}\")\n",
 "                accelerator.save_state(save_path)\n",
 "                save_model_card(config, repo_id, images, repo_folder=config.output_dir)\n",
 "                upload_folder(\n",
 "                    repo_id=repo_id,\n",
 "                    folder_path=save_path,\n",
 "                    path_in_repo=f\"checkpoint-{global_step}\",\n",
 "                    commit_message=\"Latest Checkpoint\",\n",
 "                    ignore_patterns=[\"step_*\", \"epoch_*\"],\n",
 "                )\n",
 "                upload_folder(\n",
 "                    repo_id=repo_id,\n",
 "                    folder_path=os.path.join(config.output_dir, \"samples\"),\n",
 "                    path_in_repo=\"samples\",\n",
 "                    commit_message=\"new samples\",\n",
 "                    ignore_patterns=[\"step_*\", \"epoch_*\"],\n",
 "                )\n",
 "                upload_file(\n",
 "                    path_or_fileobj=os.path.join(config.output_dir, \"README.md\"),\n",
 "                    path_in_repo=\"README.md\",\n",
 "                    repo_id=repo_id,\n",
 "                )\n",
 "\n",
 "    # Save the final fine-tuned pipeline locally once training is done.\n",
 "    unet = accelerator.unwrap_model(unet)\n",
 "    pipeline = StableDiffusionPipeline.from_pretrained(\n",
 "        config.pretrained_model_name_or_path,\n",
 "        text_encoder=text_encoder,\n",
 "        vae=vae,\n",
 "        unet=unet,\n",
 "    )\n",
 "    pipeline.save_pretrained(config.output_dir)\n",
 "    accelerator.end_training()\n"
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "L21-Cx7NrghU" }, "outputs": [], "source": [
 "config.validation_prompts[0]"
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ofrTlboPpwX9" }, "outputs": [], "source": [
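 "# Log in to the Hugging Face Hub so checkpoints, samples, and the model card can be uploaded.\n",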
 "from huggingface_hub import login\n",
 "\n",
 "login(token=config.token, add_to_git_credential=True)"
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "3o2O7BkjmNsB" }, "outputs": [], "source": [
 "dataset.set_transform(transform)\n",
 "train_dataloader = torch.utils.data.DataLoader(dataset, collate_fn=collate_fn, batch_size=config.train_batch_size, shuffle=True)\n",
 "\n",
 "noise_scheduler = DDPMScheduler.from_pretrained(config.pretrained_model_name_or_path, subfolder=\"scheduler\")\n",
 "\n",
 "# Load the frozen text encoder and VAE from the base checkpoint, with DeepSpeed ZeRO-3\n",
 "# weight sharding disabled so they are loaded in full on each process.\n",
 "with ContextManagers(deepspeed_zero_init_disabled_context_manager()):\n",
 "    text_encoder = CLIPTextModel.from_pretrained(\n",
 "        config.pretrained_model_name_or_path, subfolder=\"text_encoder\",\n",
 "    )\n",
 "    vae = AutoencoderKL.from_pretrained(\n",
 "        config.pretrained_model_name_or_path, subfolder=\"vae\",\n",
 "    )\n",
 "\n",
 "# A freshly initialized UNet (not the pretrained SD UNet): sample_size is the latent\n",
 "# resolution (image_size / 8) and cross_attention_dim matches the CLIP text encoder's\n",
 "# 768-dim hidden states.\n",
 "unet = UNet2DConditionModel(\n",
 "    sample_size=config.image_size // 8,\n",
 "    cross_attention_dim=768,\n",
 ")\n",
 "\n",
 "vae.requires_grad_(False)\n",
 "text_encoder.requires_grad_(False)\n",
 "\n",
 "# 8-bit AdamW from bitsandbytes to reduce optimizer memory.\n",
 "optimizer = bnb.optim.AdamW8bit(\n",
 "    unet.parameters(),\n",
 "    lr=config.learning_rate,\n",
 ")\n",
 "lr_scheduler = get_scheduler(\n",
 "    config.lr_scheduler,\n",
 "    optimizer=optimizer,\n",
 "    num_warmup_steps=config.lr_warmup_steps,\n",
 "    num_training_steps=(len(train_dataloader) * config.num_epochs),\n",
 ")\n",
 "\n",
 "args = (config, unet, vae, noise_scheduler, optimizer, train_dataloader, lr_scheduler)\n"
] }, { "cell_type": "code", "source": [
 "notebook_launcher(train_loop, args, num_processes=1)"
], "metadata": { "id": "GCR1zr9EKLyw" }, "execution_count": null, "outputs": [] } ], "metadata": { "accelerator": "GPU", "colab": { "provenance": [], "gpuType": "T4" }, "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "name": "python" } }, "nbformat": 4, "nbformat_minor": 0 }