File size: 1,734 Bytes
08720f3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
import torch
from chat_anything.face_generator.pipelines.lpw_stable_diffusion import StableDiffusionLongPromptWeightingPipeline 

@torch.no_grad()
def generate(pipe, prompt, negative_prompt, **generating_conf):
    """Run ``pipe`` using long-prompt-weighted text embeddings.

    The prompt and negative prompt are encoded with
    ``StableDiffusionLongPromptWeightingPipeline._encode_prompt`` (built from
    ``pipe``'s own components). The resulting embeddings are then passed back
    to ``pipe`` through ``prompt_embeds`` / ``negative_prompt_embeds``. This
    avoids the tokenizer's prompt-length truncation that a plain ``pipe(prompt)``
    call would apply.

    Args:
        pipe: A diffusers-style pipeline exposing ``unet``, ``text_encoder``,
            ``vae``, ``tokenizer``, ``scheduler`` and ``device``, and callable
            with ``prompt_embeds`` / ``negative_prompt_embeds`` keyword args.
        prompt: Positive prompt text.
        negative_prompt: Negative prompt text.
        **generating_conf: Extra keyword arguments forwarded verbatim to
            ``pipe``. ``guidance_scale`` decides whether classifier-free
            guidance (CFG) is used. When absent, diffusers' documented default
            of 7.5 is assumed instead of raising ``KeyError``.

    Returns:
        Whatever ``pipe(...)`` returns.
    """
    # Safety checker / feature extractor are not needed: this helper pipeline
    # is only used to encode the prompt, never to produce images.
    pipe_longprompt = StableDiffusionLongPromptWeightingPipeline(
        unet=pipe.unet,
        text_encoder=pipe.text_encoder,
        vae=pipe.vae,
        tokenizer=pipe.tokenizer,
        scheduler=pipe.scheduler,
        safety_checker=None,
        feature_extractor=None,
    )
    print('generating: ', prompt)
    print('using negative prompt: ', negative_prompt)

    # Mirror the pipeline's own default so a missing key doesn't crash here.
    guidance_scale = generating_conf.get('guidance_scale', 7.5)
    do_cfg = guidance_scale > 1

    # num_images_per_prompt=1: `pipe` itself repeats the embeddings if the
    # caller passes num_images_per_prompt in generating_conf.
    embeds = pipe_longprompt._encode_prompt(
        prompt=prompt,
        negative_prompt=negative_prompt,
        device=pipe.device,
        num_images_per_prompt=1,
        do_classifier_free_guidance=do_cfg,
    )
    if do_cfg:
        # Assumes the encoder stacks [negative; positive] along dim 0 under CFG
        # (the same order the original unpacking relied on).
        negative_prompt_embeds, prompt_embeds = embeds.chunk(2)
    else:
        # Without CFG only positive embeddings come back. Halving the batch
        # would request a split size of 0 for a single prompt and raise.
        prompt_embeds, negative_prompt_embeds = embeds, None

    pipe_out = pipe(
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        **generating_conf,
    )
    return pipe_out

if __name__ == '__main__':
    import argparse
    import os

    from diffusers.pipelines import StableDiffusionPipeline

    def main():
        """Generate images for ``--prompts`` with plain Stable Diffusion and save them as PNGs.

        Each image is written to the current directory as ``<prompt>_<index>.png``.
        Path separators in the prompt are replaced by ``_`` so free-text prompts
        cannot redirect the output into a subdirectory.
        """
        parser = argparse.ArgumentParser()
        # '+' rather than '*': an explicitly empty prompt list would only fail
        # later inside the pipeline.
        parser.add_argument(
            '--prompts',
            type=str,
            nargs='+',
            default=['starry night', 'Impression Sunrise, drawn by Claude Monet'],
        )
        parser.add_argument('--model_id', type=str, default='pretrained_model/sd-v1-4')
        parser.add_argument('--device', type=str, default='cuda')

        args = parser.parse_args()
        prompts = args.prompts
        print(f'generating {prompts}')
        pipe = StableDiffusionPipeline.from_pretrained(args.model_id).to(args.device)
        images = pipe(prompts).images
        # Indexing prompts[i] assumes one image per prompt (no
        # num_images_per_prompt is passed above).
        for i, image in enumerate(images):
            stem = prompts[i].replace('/', '_').replace(os.sep, '_')
            image.save(f'{stem}_{i}.png')

    main()