# PortraitDiffusion/utils/style_attn_control.py
# Commit 9bf9ce7 ("initial add") by Jinl; 13.1 kB.
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from re import U
import numpy as np
from einops import rearrange
from .masactrl_utils import AttentionBase
from torchvision.utils import save_image
import sys
import torch
import torch.nn.functional as F
from torch import nn
import torch.fft as fft
from einops import rearrange, repeat
from diffusers.utils import deprecate, logging
from diffusers.utils.import_utils import is_xformers_available
# from masactrl.masactrl import MutualSelfAttentionControl
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
if is_xformers_available():
import xformers
import xformers.ops
else:
xformers = None
class AttentionBase:
    """Minimal attention controller invoked once per attention-layer call.

    Tracks which attention layer of the current denoising step is running
    (``cur_att_layer``) and which step the sampler is on (``cur_step``). After
    ``num_att_layers`` calls the layer counter wraps to 0, the step counter
    advances and ``after_step`` fires. With the default ``num_att_layers = -1``
    the step counter never advances; presumably the code that registers the
    controller sets the real layer count (TODO confirm).

    NOTE: this definition replaces the ``AttentionBase`` imported from
    ``.masactrl_utils`` at the top of the module.
    """

    def __init__(self):
        self.cur_step = 0
        self.num_att_layers = -1
        self.cur_att_layer = 0

    def after_step(self):
        """Hook run after the last attention layer of each step; no-op here."""

    def __call__(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):
        """Run ``forward`` and advance the layer/step bookkeeping."""
        result = self.forward(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs)
        self.cur_att_layer += 1
        if self.cur_att_layer != self.num_att_layers:
            return result
        # Every attention layer has been visited once: start the next step.
        self.cur_att_layer = 0
        self.cur_step += 1
        self.after_step()
        return result

    def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):
        """Plain attention: apply ``attn`` to ``v`` and merge the heads.

        ``attn``/``v`` are batched as ``(b h)`` per the rearrange pattern; the
        result is ``b n (h d)``.
        """
        weighted = torch.einsum("b i j, b j d -> b i d", attn, v)
        return rearrange(weighted, "(b h) n d -> b n (h d)", h=num_heads)

    def reset(self):
        """Rewind the step and layer counters to the beginning."""
        self.cur_step, self.cur_att_layer = 0, 0
class MaskPromptedStyleAttentionControl(AttentionBase):
    """Mask-prompted style attention control (MaskedSAC).

    In controlled self-attention calls, ``forward`` splits the query batch into
    three equal chunks, named here style (s), content (c) and target (t). The
    content and target queries attend to the keys/values of the first
    ``num_heads`` rows (the style chunk). When both ``style_mask`` and
    ``source_mask`` are given, that attention is further restricted by the
    masks, resized to the token grid of the current layer.
    """

    def __init__(self, start_step=4, start_layer=10, style_attn_step=35, layer_idx=None, step_idx=None,
                 total_steps=50, style_guidance=0.1, only_masked_region=False, guidance=0.0,
                 style_mask=None, source_mask=None, de_bug=False):
        """
        MaskPromptedSAC

        Args:
            start_step: first denoising step to control (used only when
                ``step_idx`` is None).
            start_layer: first U-Net layer to control (used only when
                ``layer_idx`` is None).
            style_attn_step: denoising step from which the target-chunk queries
                are used; before it the content-chunk queries stand in for them.
            layer_idx: explicit list of U-Net layer indices to control; defaults
                to ``range(start_layer, 16)``.
            step_idx: explicit list of denoising steps to control; defaults to
                ``range(start_step, total_steps)``.
            total_steps: total number of denoising steps.
            style_guidance: weight ``w`` in ``out_c + (out_t - out_c) * w``; a
                negative value disables this interpolation.
            only_masked_region: when both masks are given, True stylizes only
                the source-mask region (content output elsewhere); False
                stylizes foreground and background separately.
            guidance: stored on the instance; not read by this class.
            style_mask: mask for the style image, fed to ``F.interpolate``;
                presumably a binary (1, 1, H, W) tensor -- TODO confirm.
            source_mask: mask for the source image, same layout as
                ``style_mask``.
            de_bug: when True, drop into ``pdb`` at several points.
        """
        super().__init__()
        self.total_steps = total_steps
        self.total_layers = 16
        self.start_step = start_step
        self.start_layer = start_layer
        self.layer_idx = layer_idx if layer_idx is not None else list(range(start_layer, self.total_layers))
        self.step_idx = step_idx if step_idx is not None else list(range(start_step, total_steps))
        print("using MaskPromptStyleAttentionControl")
        print("MaskedSAC at denoising steps: ", self.step_idx)
        print("MaskedSAC at U-Net layers: ", self.layer_idx)
        self.de_bug = de_bug
        self.style_guidance = style_guidance
        self.only_masked_region = only_masked_region
        self.style_attn_step = style_attn_step
        self.self_attns = []
        self.cross_attns = []
        self.guidance = guidance  # stored only; not read anywhere in this class
        self.style_mask = style_mask
        self.source_mask = source_mask

    def after_step(self):
        """Clear the per-step attention buffers."""
        self.self_attns = []
        self.cross_attns = []

    def attn_batch(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, q_mask, k_mask, **kwargs):
        """Multi-head attention with the batch folded into the token axis.

        ``q``/``k``/``v`` are ``(b h) n d``; they are rearranged to
        ``h (b n) d``, so every query attends to the keys of all batch items.
        If ``attn`` is given it is used as-is. Otherwise the scaled scores are
        masked and soft-maxed. A ``q_mask`` of shape (N, 1) blanks whole query
        rows where it is 0; a ``k_mask`` of shape (N, 1) blanks key columns
        where it is 0. The incoming ``sim`` is ignored.

        Returns ``(h1 b) n (h d)``, where h1 is 2 when ``attn`` holds twice as
        many heads as ``v`` (``v`` is duplicated to match) and 1 otherwise.
        """
        B = q.shape[0] // num_heads
        q = rearrange(q, "(b h) n d -> h (b n) d", h=num_heads)
        k = rearrange(k, "(b h) n d -> h (b n) d", h=num_heads)
        v = rearrange(v, "(b h) n d -> h (b n) d", h=num_heads)
        if attn is None:
            # Scores are only needed when no precomputed attention map is supplied.
            sim = torch.einsum("h i d, h j d -> h i j", q, k) * kwargs.get("scale")
            neg = -torch.finfo(sim.dtype).max
            if q_mask is not None:
                sim = sim.masked_fill(q_mask.unsqueeze(0) == 0, neg)
            if k_mask is not None:
                sim = sim.masked_fill(k_mask.permute(1, 0).unsqueeze(0) == 0, neg)
            attn = sim.softmax(-1)
        if len(attn) == 2 * len(v):
            v = torch.cat([v] * 2)
        out = torch.einsum("h i j, h j d -> h i d", attn, v)
        out = rearrange(out, "(h1 h) (b n) d -> (h1 b) n (h d)", b=B, h=num_heads)
        return out

    def attn_batch_fg_bg(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, q_mask, k_mask, **kwargs):
        """Compute foreground and background attention in one pass.

        The foreground scores keep query rows and key columns where the masks
        are 1; the background scores keep those where they are 0. Masks are
        applied cumulatively, so a query-row mask and a key-column mask both
        take effect. Returns the foreground output stacked on top of the
        background output along the batch axis (use ``.chunk(2)`` to split).
        The ``sim`` and ``attn`` arguments are ignored.
        """
        B = q.shape[0] // num_heads
        q = rearrange(q, "(b h) n d -> h (b n) d", h=num_heads)
        k = rearrange(k, "(b h) n d -> h (b n) d", h=num_heads)
        v = rearrange(v, "(b h) n d -> h (b n) d", h=num_heads)
        sim = torch.einsum("h i d, h j d -> h i j", q, k) * kwargs.get("scale")
        neg = -torch.finfo(sim.dtype).max
        sim_fg, sim_bg = sim, sim
        if q_mask is not None:
            q_sel = q_mask.unsqueeze(0)
            sim_fg = sim_fg.masked_fill(q_sel == 0, neg)
            sim_bg = sim_bg.masked_fill(q_sel == 1, neg)
        if k_mask is not None:
            # Build on the q-masked scores instead of restarting from ``sim``.
            k_sel = k_mask.permute(1, 0).unsqueeze(0)
            sim_fg = sim_fg.masked_fill(k_sel == 0, neg)
            sim_bg = sim_bg.masked_fill(k_sel == 1, neg)
        attn = torch.cat([sim_fg, sim_bg]).softmax(-1)
        if len(attn) == 2 * len(v):
            v = torch.cat([v] * 2)
        out = torch.einsum("h i j, h j d -> h i d", attn, v)
        out = rearrange(out, "(h1 h) (b n) d -> (h1 b) n (h d)", b=B, h=num_heads)
        return out

    def _resize_masks(self, q):
        """Resize the source/style masks to the token grid of ``q``.

        Returns ``(source, style)`` as (N, 1) tensors on ``q``'s device and in
        its dtype, or ``(None, None)`` when either mask is missing. The grid
        keeps the mask's aspect ratio, downscaled by
        ``sqrt(H * W / q.shape[1])``.
        """
        if self.style_mask is None or self.source_mask is None:
            return None, None
        height, width = self.style_mask.shape[-2:]
        scale = int(np.sqrt(height * width / q.shape[1]))
        size = (height // scale, width // scale)
        source = F.interpolate(self.source_mask, size).reshape(-1, 1)
        style = F.interpolate(self.style_mask, size).reshape(-1, 1)
        # Match q so masked_fill / blending neither fail on device nor upcast fp16.
        return source.to(device=q.device, dtype=q.dtype), style.to(device=q.device, dtype=q.dtype)

    def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):
        """
        Attention forward function.

        Cross-attention calls, and steps/layers outside ``step_idx`` /
        ``layer_idx``, use plain attention. ``cur_att_layer // 2`` is compared
        against ``layer_idx``. Otherwise the call is routed to the no-mask,
        masked-region-only, or separate fg/bg controller, and the three chunk
        outputs are concatenated along the batch axis.
        """
        if is_cross or self.cur_step not in self.step_idx or self.cur_att_layer // 2 not in self.layer_idx:
            return super().forward(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs)
        spatial_mask_source, spatial_mask_style = self._resize_masks(q)
        if spatial_mask_style is None or spatial_mask_source is None:
            controller = self.style_attn_ctrl
        elif self.only_masked_region:
            controller = self.mask_prompted_style_attn_ctrl
        else:
            controller = self.separate_mask_prompted_style_attn_ctrl
        out_s, out_c, out_t = controller(q, k, v, sim, attn, is_cross, place_in_unet, num_heads,
                                         spatial_mask_source, spatial_mask_style, **kwargs)
        return torch.cat([out_s, out_c, out_t], dim=0)

    def style_attn_ctrl(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads,
                        spatial_mask_source, spatial_mask_style, **kwargs):
        """No-mask controller.

        The style chunk reuses its own attention map. The content and target
        queries attend to the style chunk's keys/values. Before
        ``style_attn_step`` the content queries also drive the target branch.
        When ``style_guidance >= 0`` the target output is interpolated toward
        the content output.
        """
        if self.de_bug:
            import pdb
            pdb.set_trace()
        qs, qc, qt = q.chunk(3)
        k_style, v_style, sim_style = k[:num_heads], v[:num_heads], sim[:num_heads]
        out_s = self.attn_batch(qs, k_style, v_style, sim_style, attn[:num_heads], is_cross, place_in_unet,
                                num_heads, q_mask=None, k_mask=None, **kwargs)
        out_c = self.attn_batch(qc, k_style, v_style, sim_style, None, is_cross, place_in_unet,
                                num_heads, q_mask=None, k_mask=None, **kwargs)
        q_target = qc if self.cur_step < self.style_attn_step else qt
        out_t = self.attn_batch(q_target, k_style, v_style, sim_style, None, is_cross, place_in_unet,
                                num_heads, q_mask=None, k_mask=None, **kwargs)
        if self.style_guidance >= 0:
            out_t = out_c + (out_t - out_c) * self.style_guidance
        return out_s, out_c, out_t

    def mask_prompted_style_attn_ctrl(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads,
                                      spatial_mask_source, spatial_mask_style, **kwargs):
        """Stylize only the source-mask region.

        The style and content chunks reuse their own keys/values and attention
        maps. Before ``style_attn_step`` the target output equals the content
        output. From then on, target and content queries attend to the style
        keys under the masks. The (optionally ``style_guidance``-interpolated)
        target result is pasted inside the source mask, with the content
        output kept elsewhere.
        """
        qs, qc, qt = q.chunk(3)
        out_s = self.attn_batch(qs, k[:num_heads], v[:num_heads], sim[:num_heads], attn[:num_heads],
                                is_cross, place_in_unet, num_heads, q_mask=None, k_mask=None, **kwargs)
        content = slice(num_heads, 2 * num_heads)
        out_c = self.attn_batch(qc, k[content], v[content], sim[content], attn[content],
                                is_cross, place_in_unet, num_heads, q_mask=None, k_mask=None, **kwargs)
        if self.de_bug:
            import pdb
            pdb.set_trace()
        if self.cur_step < self.style_attn_step:
            out_t = out_c
        else:
            out_t_fg = self.attn_batch(qt, k[:num_heads], v[:num_heads], sim[:num_heads], None, is_cross,
                                       place_in_unet, num_heads, q_mask=spatial_mask_source,
                                       k_mask=spatial_mask_style, **kwargs)
            out_c_fg = self.attn_batch(qc, k[:num_heads], v[:num_heads], sim[:num_heads], None, is_cross,
                                       place_in_unet, num_heads, q_mask=spatial_mask_source,
                                       k_mask=spatial_mask_style, **kwargs)
            # Negative guidance means "no interpolation", as in style_attn_ctrl;
            # previously out_t was left unbound in that case.
            out_t = out_t_fg
            if self.style_guidance >= 0:
                out_t = out_c_fg + (out_t_fg - out_c_fg) * self.style_guidance
            out_t = out_t * spatial_mask_source + out_c * (1 - spatial_mask_source)
            if self.de_bug:
                import pdb
                pdb.set_trace()
        return out_s, out_c, out_t

    def separate_mask_prompted_style_attn_ctrl(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads,
                                               spatial_mask_source, spatial_mask_style, **kwargs):
        """Stylize foreground and background separately to avoid query confusion.

        Foreground queries attend only to foreground style tokens, and
        background to background. The two results are blended with the source
        mask. Before ``style_attn_step`` the content queries are used. After
        it, the target queries are used, optionally interpolated toward the
        content result by ``style_guidance``.
        """
        if self.de_bug:
            import pdb
            pdb.set_trace()
        qs, qc, qt = q.chunk(3)
        out_s = self.attn_batch(qs, k[:num_heads], v[:num_heads], sim[:num_heads], attn[:num_heads],
                                is_cross, place_in_unet, num_heads, q_mask=None, k_mask=None, **kwargs)
        out_c = self.attn_batch_fg_bg(qc, k[:num_heads], v[:num_heads], sim[:num_heads], attn, is_cross,
                                      place_in_unet, num_heads, q_mask=spatial_mask_source,
                                      k_mask=spatial_mask_style, **kwargs)
        out_c_fg, out_c_bg = out_c.chunk(2)
        if self.cur_step < self.style_attn_step:
            out_t = out_c_fg * spatial_mask_source + out_c_bg * (1 - spatial_mask_source)
        else:
            out_t = self.attn_batch_fg_bg(qt, k[:num_heads], v[:num_heads], sim[:num_heads], attn, is_cross,
                                          place_in_unet, num_heads, q_mask=spatial_mask_source,
                                          k_mask=spatial_mask_style, **kwargs)
            out_t_fg, out_t_bg = out_t.chunk(2)
            if self.style_guidance >= 0:
                out_t_fg = out_c_fg + (out_t_fg - out_c_fg) * self.style_guidance
                out_t_bg = out_c_bg + (out_t_bg - out_c_bg) * self.style_guidance
            out_t = out_t_fg * spatial_mask_source + out_t_bg * (1 - spatial_mask_source)
        # NOTE(review): the content slot also receives out_t (out_c here is the
        # stacked fg/bg tensor, twice the batch size). Confirm this is intended
        # rather than returning the blended content result.
        return out_s, out_t, out_t