File size: 1,425 Bytes
e7d5680
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
from functools import partial

import torch

from opensora.registry import SCHEDULERS

from .dpm_solver import DPMS


@SCHEDULERS.register_module("dpm-solver")
class DMP_SOLVER:
    """Scheduler registered as ``"dpm-solver"`` that samples through ``DPMS``.

    Attributes:
        num_sampling_steps: Value forwarded to ``DPMS.sample`` as ``steps``.
            It is passed through unchanged, including the default ``None``.
        cfg_scale: Value forwarded to ``DPMS`` as ``cfg_scale``
            (presumably the classifier-free guidance scale -- confirm
            against ``dpm_solver.DPMS``).
    """

    def __init__(self, num_sampling_steps=None, cfg_scale=4.0):
        """Store the sampling configuration; no other work is done here."""
        self.cfg_scale = cfg_scale
        self.num_sampling_steps = num_sampling_steps

    def sample(self, model, text_encoder, z_size, prompts, device, additional_args=None):
        """Draw one sample per prompt by running the DPM solver from noise.

        Args:
            model: Object whose ``forward`` is invoked through
                ``forward_with_dpmsolver``.
            text_encoder: Must provide ``encode(prompts)``, returning a
                mapping that contains the key ``"y"``, and ``null(n)``.
            z_size: Iterable of trailing dimensions for the initial noise;
                the batch dimension ``len(prompts)`` is prepended.
            prompts: Sized collection of prompts; its length sets the batch.
            device: Device on which the initial noise is allocated.
            additional_args: Optional mapping merged into the model keyword
                arguments after ``"y"`` has been removed from them.

        Returns:
            Whatever ``DPMS.sample`` returns for the drawn noise.
        """
        batch_size = len(prompts)
        # Noise is drawn before the text encoder runs, as in the original
        # statement order, so the RNG consumption order is unchanged.
        noise = torch.randn(batch_size, *z_size, device=device)

        # Whatever the encoder returns besides "y" is passed to the model as
        # keyword arguments; "y" itself is handed to the solver as condition.
        extra_kwargs = text_encoder.encode(prompts)
        cond = extra_kwargs.pop("y")
        uncond = text_encoder.null(batch_size)
        if additional_args is not None:
            extra_kwargs.update(additional_args)

        # Bind the model as the first argument of the module-level adapter.
        denoiser = partial(forward_with_dpmsolver, model)
        solver = DPMS(
            denoiser,
            condition=cond,
            uncondition=uncond,
            cfg_scale=self.cfg_scale,
            model_kwargs=extra_kwargs,
        )
        return solver.sample(
            noise,
            steps=self.num_sampling_steps,
            order=2,
            skip_type="time_uniform",
            method="multistep",
        )


def forward_with_dpmsolver(self, x, timestep, y, **kwargs):
    """Run ``self.forward`` and keep only the first half of its output.

    The output is split into two chunks along dimension 1 and the first
    chunk is returned. Per the original note, the DPM solver does not need
    the variance prediction (presumably carried by the discarded chunk).

    Args:
        self: Model object exposing ``forward(x, timestep, y, **kwargs)``.
        x: First positional input forwarded to the model.
        timestep: Second positional input forwarded to the model.
        y: Third positional input forwarded to the model.
        **kwargs: Extra keyword arguments forwarded unchanged.

    Returns:
        The first of the chunks produced by ``chunk(2, dim=1)``.
    """
    # Reference: https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
    prediction = self.forward(x, timestep, y, **kwargs)
    # Index rather than unpack: chunk() may yield a single piece when
    # dimension 1 has size 1, and indexing keeps that case working.
    pieces = prediction.chunk(2, dim=1)
    return pieces[0]