File size: 5,910 Bytes
7cf0db3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import logging

import torch
import comfy


# Check and add 'model_patch' to model.model_options['transformer_options']
def add_model_patch_option(model):
    """Make sure the ``model_patch`` slot exists in the model's transformer options.

    Both ``model.model_options['transformer_options']`` and its nested
    ``'model_patch'`` entry are created as empty dicts when absent; entries
    that already exist are left untouched.

    Returns:
        The ``transformer_options`` dict (the parent of ``model_patch``),
        which is the same object stored in ``model.model_options``.
    """
    transformer_options = model.model_options.setdefault('transformer_options', {})
    transformer_options.setdefault("model_patch", {})
    return transformer_options


# Patch model with model_function_wrapper
def patch_model_function_wrapper(model, forward_patch, remove=False):
    """Register (or unregister) a forward patch that runs before each model call.

    Installs ``brushnet_model_function_wrapper`` through
    ``model.set_model_unet_function_wrapper``, keeps its bookkeeping under
    ``model.model_options['transformer_options']['model_patch']`` and, if not
    done already, replaces ``comfy.samplers.sample`` and
    ``comfy.ldm.modules.diffusionmodules.openaimodel.apply_control`` with the
    BrushNet versions defined in this file.

    Args:
        model: object exposing ``model_options``, ``model`` and
            ``set_model_unet_function_wrapper`` (presumably a ComfyUI
            ModelPatcher -- TODO confirm against callers).
        forward_patch: callable invoked as
            ``forward_patch(unet, xc, t, transformer_options, control)``
            before every wrapped model call; its return value is ignored.
        remove: when true, ``forward_patch`` is removed from the registered
            patches (if present) instead of being appended.

    Raises:
        Exception: if ``model.model.model_config`` is neither an SD15 nor an
            SDXL config. Note that the function wrapper has already been
            installed on ``model`` by the time this is raised.
    """

    def brushnet_model_function_wrapper(apply_model_method, options_dict):
        """Run all registered forward patches, then call the real model.

        ``options_dict`` is read for the keys ``'c'`` (conditioning kwargs,
        including ``'transformer_options'``), ``'input'`` and ``'timestep'``.
        The result of ``apply_model_method`` is returned unchanged.
        """
        to = options_dict['c']['transformer_options']

        # ControlNet output is optional; patches receive None when absent.
        control = None
        if 'control' in options_dict['c']:
            control = options_dict['c']['control']
        
        x = options_dict['input']
        timestep = options_dict['timestep']
        
        # check if there are patches to execute
        if 'model_patch' not in to or 'forward' not in to['model_patch']:
            return apply_model_method(x, timestep, **options_dict['c'])

        mp = to['model_patch']
        unet = mp['unet']
        


        # Debug aids, kept disabled:
        #print(model.get_model_object("model_sampling").sigmas, len(model.get_model_object("model_sampling").sigmas))
        #print(mp['all_sigmas'], len(mp['all_sigmas']))


        # Derive the current step as the index of the schedule entry closest
        # to the sigma of this call. 'all_sigmas' is written by
        # modified_sample; a KeyError here means sampling did not go through it.
        all_sigmas = mp['all_sigmas']
        sigma = to['sigmas'][0].item()
        total_steps = all_sigmas.shape[0] - 1
        step = torch.argmin((all_sigmas - sigma).abs()).item()

        # Publish progress so the patches can read it from 'model_patch'.
        mp['step'] = step
        mp['total_steps'] = total_steps

        # comfy.model_base.apply_model
        # NOTE(review): the lines below appear to replicate the input
        # preparation done there -- confirm against the installed ComfyUI.
        xc = model.model.model_sampling.calculate_input(timestep, x)
        if 'c_concat' in options_dict['c'] and options_dict['c']['c_concat'] is not None:
            # Extra conditioning is concatenated along dim 1.
            xc = torch.cat([xc] + [options_dict['c']['c_concat']], dim=1)
        t = model.model.model_sampling.timestep(timestep).float()
        # execute all patches (in registration order; return values are ignored)
        for method in mp['forward']:
            method(unet, xc, t, to, control)
            
        # The real model call still receives the original x / timestep, not xc / t.
        return apply_model_method(x, timestep, **options_dict['c'])
    
    # Only one model_function_wrapper can be set; announce when an existing one is replaced.
    if "model_function_wrapper" in model.model_options and model.model_options["model_function_wrapper"]:
        print('BrushNet is going to replace existing model_function_wrapper:', model.model_options["model_function_wrapper"])
    model.set_model_unet_function_wrapper(brushnet_model_function_wrapper)   

    to = add_model_patch_option(model)
    mp = to['model_patch']

    # Record which base architecture is in use; anything else is rejected.
    if isinstance(model.model.model_config, comfy.supported_models.SD15):
        mp['SDXL'] = False
    elif isinstance(model.model.model_config, comfy.supported_models.SDXL):
        mp['SDXL'] = True
    else:
        print('Base model type: ', type(model.model.model_config))
        raise Exception("Unsupported model type: ", type(model.model.model_config))

    if 'forward' not in mp:
        mp['forward'] = []

    # With remove=True the 'forward' key stays in place (possibly as an empty
    # list), so the wrapper keeps computing step information afterwards.
    if remove:
        if forward_patch in mp['forward']:
            mp['forward'].remove(forward_patch)
    else:
        mp['forward'].append(forward_patch)

    # Initial bookkeeping; 'step' and 'total_steps' are overwritten on every wrapped call.
    mp['unet'] = model.model.diffusion_model
    mp['step'] = 0
    mp['total_steps'] = 1

    # apply patches to code
    # The replacement functions carry 'BrushNet' in their docstrings; that
    # marker is what prevents patching twice (and overwriting the saved original
    # with an already-patched function).
    if comfy.samplers.sample.__doc__ is None or 'BrushNet' not in comfy.samplers.sample.__doc__:
        comfy.samplers.original_sample = comfy.samplers.sample
        comfy.samplers.sample = modified_sample

    if comfy.ldm.modules.diffusionmodules.openaimodel.apply_control.__doc__ is None or \
        'BrushNet' not in comfy.ldm.modules.diffusionmodules.openaimodel.apply_control.__doc__:
        comfy.ldm.modules.diffusionmodules.openaimodel.original_apply_control = comfy.ldm.modules.diffusionmodules.openaimodel.apply_control
        comfy.ldm.modules.diffusionmodules.openaimodel.apply_control = modified_apply_control


# Model needs current step number and cfg at inference step. It is possible to write a custom KSampler but I'd like to use ComfyUI's one.
# The first versions had modified_common_ksampler, but it broke custom KSampler nodes
def modified_sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model_options={}, 
           latent_image=None, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
    '''
    Modified by BrushNet nodes

    Replacement for ``comfy.samplers.sample`` that stores the full sigma
    schedule in the model's ``model_patch`` options before sampling with a
    ``comfy.samplers.CFGGuider``.

    ``device`` and ``model_options`` are part of the signature but are not
    used by this function. The word "BrushNet" must stay in this docstring:
    it is the marker checked before monkey-patching to avoid patching twice.
    '''
    guider = comfy.samplers.CFGGuider(model)
    guider.set_conds(positive, negative)
    guider.set_cfg(cfg)

    # --- BrushNet addition ---------------------------------------------------
    # Expose the whole schedule so the model function wrapper can work out the
    # current step from the sigma of each call.
    # (Start/end sigma handling and a cfg-dependent 'free_guidance' flag were
    # sketched here earlier; neither is active.)
    add_model_patch_option(model)['model_patch']['all_sigmas'] = sigmas
    # -------------------------------------------------------------------------

    return guider.sample(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)


# To use Controlnet with RAUNet it is much easier to modify apply_control a little
def modified_apply_control(h, control, name):
    '''
    Modified by BrushNet nodes

    Add the next pending control tensor for ``name`` to ``h`` in place.

    The last entry of ``control[name]`` is popped (so ``control`` is mutated).
    If its last two spatial dimensions differ from those of ``h`` it is
    resized with bicubic interpolation and moved to ``h``'s dtype and device
    before being added.

    Args:
        h: 4-D tensor to receive the control signal (dims 2 and 3 are compared).
        control: mapping from block name to a list of control tensors, or None.
        name: key selecting which list in ``control`` to consume.

    Returns:
        ``h`` itself. It is unchanged when there is nothing to apply or when
        the addition fails, in which case a warning is logged.

    The word "BrushNet" must stay in this docstring: it is the marker checked
    before monkey-patching to avoid patching twice.
    '''
    if control is None or name not in control or len(control[name]) == 0:
        return h

    ctrl = control[name].pop()
    if ctrl is None:
        return h

    if h.shape[2] != ctrl.shape[2] or h.shape[3] != ctrl.shape[3]:
        ctrl = torch.nn.functional.interpolate(
            ctrl, size=(h.shape[2], h.shape[3]), mode='bicubic'
        ).to(h.dtype).to(h.device)

    try:
        h += ctrl
    except Exception:
        # Best effort: an incompatible control tensor is skipped, not fatal.
        # The original called ``print.warning``, which raised AttributeError
        # here and turned the intended warning into a crash.
        logging.warning("warning control could not be applied {} {}".format(h.shape, ctrl.shape))
    return h