sejamenath2023's picture
Upload 12 files
239ee43
raw
history blame contribute delete
No virus
6.47 kB
import click
import torch
from pathlib import Path
import pkgutil
from imagen_pytorch import load_imagen_from_checkpoint
from imagen_pytorch.version import __version__
from imagen_pytorch.data import Collator
from imagen_pytorch.utils import safeget
from imagen_pytorch import ImagenTrainer, ElucidatedImagenConfig, ImagenConfig
from datasets import load_dataset
import json
def exists(val):
    """Return True for any value other than ``None`` (falsy values count as existing)."""
    if val is None:
        return False
    return True
def simple_slugify(text, max_length = 255):
return text.replace('-', '_').replace(',', '').replace(' ', '_').replace('|', '--').strip('-_')[:max_length]
def main():
    """Placeholder entry point that intentionally does nothing.

    The CLI sub-commands are registered on the ``imagen`` click group,
    not dispatched from here.
    """
    return None
@click.group()
def imagen():
    # Root click group; sub-commands attach via ``@imagen.command``.
    # No docstring on purpose: click would display it as the group's help text.
    return None
@imagen.command(help = 'Sample from the Imagen model checkpoint')
@click.option('--model', default = './imagen.pt', help = 'path to trained Imagen model')
# Explicit float type: with an int default click would infer INT and reject e.g. 7.5.
@click.option('--cond_scale', default = 5., type = float, help = 'conditioning scale (classifier free guidance) in decoder')
@click.option('--load_ema', default = True, help = 'load EMA version of unets if available')
@click.argument('text')
def sample(
    model,
    cond_scale,
    load_ema,
    text
):
    """Generate one image for *text* from a saved Imagen checkpoint.

    The first returned PIL image is written to ``./<slugified text>.png``
    in the current working directory.

    Raises:
        AssertionError: if the checkpoint file does not exist.
    """
    model_path = Path(model)
    full_model_path = str(model_path.resolve())
    assert model_path.exists(), f'model not found at {full_model_path}'

    # Only the version string is needed from this load. Map to CPU so that
    # peeking at the checkpoint neither requires nor occupies a GPU.
    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    loaded = torch.load(str(model_path), map_location = 'cpu')
    version = safeget(loaded, 'version')
    del loaded  # only `version` is used; release the rest of the checkpoint
    print(f'loading Imagen from {full_model_path}, saved at version {version} - current package version is {__version__}')

    # get imagen parameters and type
    imagen_model = load_imagen_from_checkpoint(str(model_path), load_ema_if_available = load_ema)
    # Move to GPU only when one is available, so CPU-only hosts can still sample.
    if torch.cuda.is_available():
        imagen_model.cuda()

    # generate image
    pil_images = imagen_model.sample(text, cond_scale = cond_scale, return_pil_images = True)

    # Leave room for the suffix so the file name stays within the common
    # 255-character limit (the slug is truncated by characters, not bytes).
    suffix = '.png'
    slug = simple_slugify(text, max_length = 255 - len(suffix))
    image_path = f'./{slug}{suffix}'
    pil_images[0].save(image_path)
    print(f'image saved to {image_path}')
@imagen.command(help = 'Generate a config for the Imagen model')
@click.option('--path', default = './imagen_config.json', help = 'Path to the Imagen model config')
def config(
    path
):
    """Copy the packaged ``default_config.json`` resource to *path*.

    Raises:
        click.ClickException: if the packaged resource cannot be read.
    """
    raw = pkgutil.get_data(__name__, 'default_config.json')
    # get_data returns None when the package's loader cannot serve resources;
    # report that clearly instead of failing with AttributeError on .decode.
    if raw is None:
        raise click.ClickException('could not read the packaged default_config.json')
    data = raw.decode('utf-8')
    # Encode with the same codec used to decode, independent of the locale.
    with open(path, 'w', encoding = 'utf-8') as f:
        f.write(data)
def _pil_mode_for_channels(num_channels):
    """Map an Imagen channel count (1-4) to the matching PIL image mode.

    1 -> 'L' (greyscale), 2 -> 'LA' (greyscale + alpha),
    3 -> 'RGB', 4 -> 'RGBA' (colour + alpha).
    """
    modes = {1: 'L', 2: 'LA', 3: 'RGB', 4: 'RGBA'}
    assert num_channels in modes, 'Imagen only support 1 to 4 channels L, LA, RGB, RGBA'
    return modes[num_channels]


@imagen.command(help = 'Train the Imagen model')
@click.option('--config', default = './imagen_config.json', help = 'Path to the Imagen model config')
# Closed range [1, 3], clamping out-of-range values. The previous
# IntRange(1, 3, False, True, True) had max_open=True, which silently clamped
# `--unet 3` down to 2.
@click.option('--unet', default = 1, help = 'Unet to train', type = click.IntRange(1, 3, clamp = True))
@click.option('--epoches', default = 1000, help = 'Amount of epoches to train for')
@click.option('--text', required = False, help = 'Text to sample with between epoches', type=str)
# `--valid` alone means "validate every 50 steps"; the default 0 disables validation.
@click.option('--valid', is_flag = False, flag_value=50, default = 0, help = 'Do validation between epoches', show_default = True)
def train(
    config,
    unet,
    epoches,
    text,
    valid
):
    """Train one unet of an Imagen model described by a JSON config.

    The config must provide ``checkpoint_path``, ``imagen``, ``trainer``,
    ``dataset`` (with ``batch_size``), ``dataset_name``, ``image_label``,
    ``text_label`` and ``url_label``. If the checkpoint exists it is
    resumed. After every 100 steps, a sample for ``--text`` is saved (main
    process only, when ``--text`` is given). The checkpoint is saved once
    the loop finishes.
    """
    # check config path
    config_path = Path(config)
    full_config_path = str(config_path.resolve())
    assert config_path.exists(), f'config not found at {full_config_path}'

    with open(config_path, 'r', encoding = 'utf-8') as f:
        config_data = json.load(f)

    assert 'checkpoint_path' in config_data, 'checkpoint path not found in config'
    model_path = Path(config_data['checkpoint_path'])
    full_model_path = str(model_path.resolve())

    # setup imagen config; a missing 'type' falls back to the plain ImagenConfig
    imagen_config_klass = ElucidatedImagenConfig if config_data.get('type') == 'elucidated' else ImagenConfig
    # Named imagen_model so it does not shadow the module-level `imagen` click group.
    imagen_model = imagen_config_klass(**config_data['imagen']).create()
    trainer = ImagenTrainer(
        imagen = imagen_model,
        **config_data['trainer']
    )

    # load pt
    if model_path.exists():
        # Only the version string is needed from this load. Map to CPU so that
        # peeking does not require or occupy a GPU.
        # NOTE(review): torch.load unpickles the file -- only use trusted checkpoints.
        loaded = torch.load(str(model_path), map_location = 'cpu')
        version = safeget(loaded, 'version')
        del loaded  # only `version` is used; release the rest of the checkpoint
        print(f'loading Imagen from {full_model_path}, saved at version {version} - current package version is {__version__}')
        trainer.load(model_path)

    if torch.cuda.is_available():
        trainer = trainer.cuda()

    image_sizes = config_data['imagen']['image_sizes']
    assert unet <= len(image_sizes), f'unet {unet} requested but the config only lists {len(image_sizes)} image_sizes'
    size = image_sizes[unet - 1]
    max_batch_size = config_data.get('max_batch_size', 1)

    # PIL mode for the configured channel count; 'RGB' when 'channels' is absent.
    # (Previously 2 channels fell through to 'RGB' because of `channels == 'LA'`.)
    channels = _pil_mode_for_channels(config_data['imagen'].get('channels', 3))

    assert 'batch_size' in config_data['dataset'], 'A batch_size is required in the config file'

    def make_collator():
        # A fresh Collator per split, configured identically for train and valid.
        return Collator(
            image_size = size,
            image_label = config_data['image_label'],
            text_label = config_data['text_label'],
            url_label = config_data['url_label'],
            name = imagen_model.text_encoder_name,
            channels = channels
        )

    # load and add train dataset and valid dataset
    ds = load_dataset(config_data['dataset_name'])
    trainer.add_train_dataset(
        ds = ds['train'],
        collate_fn = make_collator(),
        **config_data['dataset']
    )

    if not trainer.split_valid_from_train and valid != 0:
        assert 'valid' in ds, 'There is no validation split in the dataset'
        trainer.add_valid_dataset(
            ds = ds['valid'],
            collate_fn = make_collator(),
            **config_data['dataset']
        )

    for i in range(epoches):
        loss = trainer.train_step(unet_number = unet, max_batch_size = max_batch_size)
        print(f'loss: {loss}')

        # `valid != 0` is checked first so the modulo never divides by zero.
        if valid != 0 and i > 0 and i % valid == 0:
            valid_loss = trainer.valid_step(unet_number = unet, max_batch_size = max_batch_size)
            print(f'valid loss: {valid_loss}')

        if i > 0 and i % 100 == 0 and trainer.is_main and text is not None:
            images = trainer.sample(texts = [text], batch_size = 1, return_pil_images = True, stop_at_unet_number = unet)
            images[0].save(f'./sample-{i // 100}.png')

    trainer.save(model_path)