import torch

from chat_anything.face_generator.pipelines.lpw_stable_diffusion import StableDiffusionLongPromptWeightingPipeline


@torch.no_grad()
def generate(pipe, prompt, negative_prompt, **generating_conf):
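    """Generate images from a StableDiffusionPipeline with long-prompt weighting.

    Reuses the modules of an existing pipeline inside the long-prompt-weighting
    (LPW) pipeline so prompts can be encoded past the usual CLIP token limit
    and with attention weighting, then runs the original pipeline on the
    precomputed embeddings.
    """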
    pipe_longprompt = StableDiffusionLongPromptWeightingPipeline(
        unet=pipe.unet,
        text_encoder=pipe.text_encoder,
        vae=pipe.vae,
        tokenizer=pipe.tokenizer,
        scheduler=pipe.scheduler,
        safety_checker=None,
        feature_extractor=None,
    )
    print('generating: ', prompt)
    print('using negative prompt: ', negative_prompt)
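    # Note: the split below assumes classifier-free guidance is enabled
    # (guidance_scale > 1); the 7.5 fallback matches the diffusers default.
    # With CFG on, _encode_prompt returns the unconditional (negative) and
    # conditional (positive) embeddings concatenated along the batch
    # dimension, negative first.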
    embeds = pipe_longprompt._encode_prompt(
        prompt=prompt,
        negative_prompt=negative_prompt,
        device=pipe.device,
        num_images_per_prompt=1,
        do_classifier_free_guidance=generating_conf.get('guidance_scale', 7.5) > 1,
    )
    negative_prompt_embeds, prompt_embeds = embeds.split(embeds.shape[0] // 2)
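    # Run the original pipeline with the precomputed embeddings, bypassing
    # its own length-limited prompt encoding.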
    pipe_out = pipe(
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        **generating_conf,
    )
    return pipe_out

if __name__ == '__main__':
    from diffusers.pipelines import StableDiffusionPipeline
    import argparse
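
    # Minimal smoke test: runs the plain StableDiffusionPipeline directly on
    # the given prompts; it does not exercise the long-prompt generate() above.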
    def main():
        parser = argparse.ArgumentParser()
        parser.add_argument(
            '--prompts', type=str, nargs='*',
            default=['starry night', 'Impression Sunrise, drawn by Claude Monet'],
        )
        args = parser.parse_args()
        prompts = args.prompts
        print(f'generating {prompts}')
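        # Assumes a local copy of the Stable Diffusion v1-4 weights in
        # pretrained_model/sd-v1-4 (e.g. fetched from the Hugging Face Hub).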
        model_id = 'pretrained_model/sd-v1-4'
        pipe = StableDiffusionPipeline.from_pretrained(model_id).to('cuda')
        images = pipe(prompts).images
        for i, image in enumerate(images):
            image.save(f'{prompts[i]}_{i}.png')

    main()