import os
import random
from collections import defaultdict
from copy import deepcopy

import kornia.filters as kf
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from PIL import Image

from detectron2.structures import ImageList
from detectron2.utils.comm import get_local_rank
from peft import LoraConfig, get_peft_model
from peft.tuners.lora.layer import LoraLayer
from peft.tuners.tuners_utils import BaseTunerLayer
from sam2.sam2_image_predictor import SAM2ImagePredictor

from data.rand_augment import RandAugment
from modeling.decoder.unet_detail_capture import MattingDetailDecoder
from modeling.semantic_enhanced_matting.condition_conv import (
    BBoxEmbedInteract,
    BBoxInteract,
    BBoxInteractInOut,
    ConditionAdd,
    ConditionConv,
    ConditionEmbedding,
)
from modeling.semantic_enhanced_matting.feature_fusion import FeatureFusion
from modeling.semantic_enhanced_matting.modeling import TwoWayTransformer
from modeling.semantic_enhanced_matting.modeling.common import LayerNorm2d
from modeling.semantic_enhanced_matting.modeling.image_encoder import PatchEmbed
from modeling.semantic_enhanced_matting.modeling.mask_decoder_hq_matting import MaskDecoderHQMatting
from modeling.semantic_enhanced_matting.predictor import SamPredictor
|


class SamHqMatte(nn.Module):
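    """SAM-HQ / SAM2 based interactive matting head.

    Wraps a SAM-style backbone behind a `SamPredictor` (or a
    `SAM2ImagePredictor` when `sam2=True`), optionally injects LoRA adapters
    and bbox/trimap conditioning into the image encoder, and decodes alpha
    mattes with `matting_decoder`.
    """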
|

    target_length = 1024

    def __init__(
        self,
        *,
        sam_model,
        hq_token_only,
        hq_features_type,
        matting_decoder,
        criterion,
        pixel_mean,
        pixel_std,
        multimask_output=False,
        vis_period=None,
        output_dir=None,
        lora_rank=None,
        lora_alpha=None,
        lora_target_modules=("qkv", "proj"),
        lora_dropout=0.1,
        w_dora=False,
        w_rslora=False,
        lora_on_mask_decoder=False,
        frozen_sam_hq_reg=None,
        reg_margin=0.85,
        w_attention_mask=False,
        alpha_reg_range=None,
        alpha_reg_weight=1.0,
        coconut_pl=False,
        coconut_pl_alpha=1.0,
        coconut_self_training=False,
        eval_w_sam_hq_mask=False,
        backbone_condition=False,
        condition_wo_conv=False,
        w_only_bbox_cond=False,
        coconut_only_known_l1=False,
        backbone_bbox_prompt=None,
        backbone_bbox_prompt_loc=(2, 3),
        backbone_bbox_prompt_loss_weight=1.0,
        concat_gen_trimap=False,
        multi_matting_decoder=None,
        w_all_logits=False,
        bbox_prompt_all_block=None,
        matting_token=False,
        test_w_hq_token=False,
        sam_hq_token_reg=None,
        feat_cross_attn_fusion=False,
        trimap_loss_type=None,
        reg_on_sam_logits=False,
        reg_w_bce_loss=False,
        complex_trimap_pred_layer=False,
        matting_token_sup=None,
        matting_token_sup_loss_weight=None,
        sam2=False,
    ):
|
        super().__init__()

        self.sam_model = sam_model
        self.sam_predictor = SamPredictor(self.sam_model) if not sam2 else SAM2ImagePredictor(self.sam_model)
        self.hq_token_only = hq_token_only
        self.multimask_output = multimask_output
        self.hq_features_type = hq_features_type

        self.matting_decoder = matting_decoder
        self.criterion = criterion

        self.register_buffer("pixel_mean", torch.tensor(pixel_mean).view(-1, 1, 1), False)
        self.register_buffer("pixel_std", torch.tensor(pixel_std).view(-1, 1, 1), False)
        assert (
            self.pixel_mean.shape == self.pixel_std.shape
        ), f"{self.pixel_mean} and {self.pixel_std} have different shapes!"

        self.vis_period = vis_period
        if output_dir is not None and output_dir != '?':
            self.output_dir = os.path.join(output_dir, 'vis_results')
            os.makedirs(self.output_dir, exist_ok=True)
        self.train_iter_index = 0

        self.lora_rank = lora_rank
        self.lora_alpha = lora_alpha
        # Copy into a fresh list: init_lora() may extend it in place.
        self.lora_target_modules = list(lora_target_modules)
        self.lora_dropout = lora_dropout
        self.w_dora = w_dora
        self.w_rslora = w_rslora
        self.lora_on_mask_decoder = lora_on_mask_decoder
        self.frozen_sam_hq_reg = frozen_sam_hq_reg
        self.reg_margin = reg_margin
        self.w_attention_mask = w_attention_mask
        self.alpha_reg_range = alpha_reg_range
        self.alpha_reg_weight = alpha_reg_weight
        self.coconut_pl = coconut_pl
        self.coconut_pl_alpha = coconut_pl_alpha
        self.coconut_self_training = coconut_self_training
        self.eval_w_sam_hq_mask = eval_w_sam_hq_mask
        self.backbone_condition = backbone_condition
        self.condition_wo_conv = condition_wo_conv
        self.w_only_bbox_cond = w_only_bbox_cond
        self.coconut_only_known_l1 = coconut_only_known_l1
        self.backbone_bbox_prompt = backbone_bbox_prompt
        self.backbone_bbox_prompt_loc = list(backbone_bbox_prompt_loc)
        self.backbone_bbox_prompt_loss_weight = backbone_bbox_prompt_loss_weight
        self.concat_gen_trimap = concat_gen_trimap
        self.multi_matting_decoder = multi_matting_decoder
        self.w_all_logits = w_all_logits
        self.bbox_prompt_all_block = bbox_prompt_all_block
        self.matting_token = matting_token
        self.test_w_hq_token = test_w_hq_token
        self.sam_hq_token_reg = sam_hq_token_reg
        self.feat_cross_attn_fusion = feat_cross_attn_fusion
        self.trimap_loss_type = trimap_loss_type
        self.reg_on_sam_logits = reg_on_sam_logits
        self.reg_w_bce_loss = reg_w_bce_loss
        self.complex_trimap_pred_layer = complex_trimap_pred_layer
        self.matting_token_sup = matting_token_sup
        self.sam2 = sam2
        assert self.matting_token_sup in {'alpha', 'trimap', None}
        self.matting_token_sup_loss_weight = matting_token_sup_loss_weight
        if self.matting_token_sup is not None:
            assert self.backbone_bbox_prompt in {'bbox', None}
        if self.frozen_sam_hq_reg is not None:
            assert self.lora_rank is not None
        if self.w_attention_mask:
            self.attention_head = deepcopy(self.matting_decoder)
        if self.coconut_self_training:
            self.rand_aug = RandAugment(3, 6)
            self.warm_iter_coconut_self_training = 5000
        if self.backbone_condition:
            assert self.lora_rank is not None
        if self.backbone_bbox_prompt is not None:
            assert self.lora_rank is not None
        if self.w_all_logits:
            self.sam_predictor.model.mask_decoder.w_all_logits = True
        if self.bbox_prompt_all_block:
            assert self.lora_rank is not None
        if self.matting_token and not self.sam2:
            self.sam_predictor.model.mask_decoder.hq_token_only = self.hq_token_only
|

    @property
    def device(self):
        return self.pixel_mean.device
|

    def init_lora(self, model=None):
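        """Attach PEFT LoRA adapters and build the conditioning layers.

        When `model` is given, wraps and returns it; otherwise wraps
        `self.sam_predictor.model.image_encoder` in place and then constructs
        the bbox/trimap condition modules. A no-op for the LoRA part when
        `lora_rank` is unset.
        """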
|
        if model is not None and self.lora_rank is not None and self.lora_rank >= 1:
            if self.lora_on_mask_decoder:
                self.lora_target_modules += ["q_proj", "k_proj", "v_proj", "out_proj"]
                modules_to_save = None
            else:
                modules_to_save = ['matting_decoder']

            lora_config = LoraConfig(
                r=self.lora_rank,
                lora_alpha=self.lora_alpha,
                use_rslora=self.w_rslora,
                use_dora=self.w_dora,
                init_lora_weights="gaussian",
                target_modules=self.lora_target_modules,
                lora_dropout=self.lora_dropout,
                modules_to_save=modules_to_save,
            )
            model = get_peft_model(model, lora_config)
            # PEFT prefixes saved module parameters; strip the prefix before
            # matching against the decoder's trainable/frozen name lists.
            if self.lora_on_mask_decoder:
                for n, p in model.matting_decoder.named_parameters():
                    if n.split('modules_to_save.default.')[-1] in model.matting_decoder.trainable_params_str:
                        p.requires_grad = True
            else:
                for n, p in model.matting_decoder.named_parameters():
                    if n.split('modules_to_save.default.')[-1] in model.matting_decoder.frozen_params_str:
                        p.requires_grad = False
            return model
        elif self.lora_rank is not None and self.lora_rank >= 1:
            lora_config = LoraConfig(
                r=self.lora_rank,
                lora_alpha=self.lora_alpha,
                use_rslora=self.w_rslora,
                use_dora=self.w_dora,
                init_lora_weights="gaussian",
                target_modules=self.lora_target_modules,
                lora_dropout=self.lora_dropout,
            )
            self.sam_predictor.model.image_encoder = get_peft_model(self.sam_predictor.model.image_encoder, lora_config)

        if self.sam2:
            # Keep the bbox-mask parameters trainable after the PEFT wrapping.
            for n, p in self.sam_predictor.model.image_encoder.named_parameters():
                if 'bbox_mask' in n:
                    p.requires_grad = True
|

        if self.backbone_condition:
            if self.w_only_bbox_cond:
                self.condition_embedding = ConditionEmbedding(condition_num=4, pos_embedding_dim=160)
            else:
                self.condition_embedding = ConditionEmbedding(condition_num=5, pos_embedding_dim=128)

            if self.condition_wo_conv:
                self.condition_conv = nn.ModuleList([ConditionAdd() for _ in range(4)])
            else:
                self.condition_conv = nn.ModuleList([
                    ConditionConv(
                        in_channels=self.sam_predictor.model.image_encoder.embed_dim,
                        out_channels=self.sam_predictor.model.image_encoder.embed_dim,
                        bottleneck_channels=512,
                    )
                    for _ in range(4)
                ])

        if self.backbone_bbox_prompt is not None and not self.sam2:
            self.condition_layer = nn.ModuleDict()
            self.condition_layer['patch_embed'] = PatchEmbed(
                kernel_size=(self.sam_predictor.model.image_encoder.patch_size, self.sam_predictor.model.image_encoder.patch_size),
                stride=(self.sam_predictor.model.image_encoder.patch_size, self.sam_predictor.model.image_encoder.patch_size),
                in_chans=4,
                embed_dim=self.sam_predictor.model.image_encoder.embed_dim,
            )
|
            if self.multi_matting_decoder is None:
                if self.backbone_bbox_prompt in {'trimap', 'alpha_trimap', 'alpha'}:
                    transformer_dim = self.sam_predictor.model.image_encoder.embed_dim
                    # Trimap heads predict 3-way logits; alpha heads predict a
                    # single sigmoid-activated channel. Both come in a simple
                    # and a deeper ("complex") variant.
                    out_ch = 1 if self.backbone_bbox_prompt == 'alpha' else 3
                    w_sigmoid = self.backbone_bbox_prompt == 'alpha'

                    def build_pred_layer():
                        if self.complex_trimap_pred_layer:
                            layers = [
                                nn.ConvTranspose2d(transformer_dim, transformer_dim // 2, kernel_size=2, stride=2),
                                LayerNorm2d(transformer_dim // 2),
                                nn.GELU(),
                                nn.Conv2d(transformer_dim // 2, transformer_dim // 4, kernel_size=3, stride=1, padding=1),
                                LayerNorm2d(transformer_dim // 4),
                                nn.GELU(),
                                nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2),
                                LayerNorm2d(transformer_dim // 8),
                                nn.GELU(),
                                nn.Conv2d(transformer_dim // 8, transformer_dim // 16, kernel_size=3, stride=1, padding=1),
                                LayerNorm2d(transformer_dim // 16),
                                nn.GELU(),
                                nn.Conv2d(transformer_dim // 16, out_ch, kernel_size=3, stride=1, padding=1),
                            ]
                        else:
                            layers = [
                                nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2),
                                LayerNorm2d(transformer_dim // 4),
                                nn.GELU(),
                                nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2),
                                nn.GELU(),
                                nn.Conv2d(transformer_dim // 8, out_ch, kernel_size=1, stride=1),
                            ]
                        if w_sigmoid:
                            layers.append(nn.Sigmoid())
                        return nn.Sequential(*layers)

                    for i in self.backbone_bbox_prompt_loc:
                        self.condition_layer['{}_pred_layer'.format(i)] = build_pred_layer()
|
            if self.bbox_prompt_all_block is not None:
                if self.bbox_prompt_all_block == 'reuse_cross-self-attn':
                    self.condition_layer['prompt_layer'] = BBoxInteract(
                        position_point_embedding=deepcopy(self.sam_predictor.model.prompt_encoder.pe_layer),
                        point_weight=deepcopy(self.sam_predictor.model.prompt_encoder.point_embeddings),
                    )
                elif self.bbox_prompt_all_block == 'in-out-bbox_cross-self-attn':
                    self.condition_layer['prompt_layer'] = BBoxInteractInOut(downsample_rate=2)
                else:
                    embed_type, interact_type = self.bbox_prompt_all_block.split('_')
                    self.condition_layer['prompt_layer'] = BBoxEmbedInteract(embed_type, interact_type)

            if self.feat_cross_attn_fusion:
                self.condition_layer['feature_fusion'] = FeatureFusion(
                    in_channels=self.sam_predictor.model.image_encoder.embed_dim,
                    attn_compression_ratio=8,
                )
|

    def condition_bbox_and_instance_num(self):
        # Disable the auxiliary conv necks of the image encoder.
        self.sam_predictor.model.image_encoder.conv_necks = None
|

    def forward_samhq_and_matting_decoder(self, images, bbox, condition_proj=None, return_hq_token=False):
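        """Run the SAM-HQ/SAM2 branch, then the matting decoder.

        In the non-SAM2 paths, `return_hq_token=True` short-circuits and
        returns only the HQ-token logits. Otherwise returns
        `(low_res_masks, pred_alpha, pred_trimap, sam_hq_matting_token)`.
        """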
|
        if self.sam2:
            interm_features, sam2_logits, matting_logits, pred_trimap = self.forward_samhq(images, bbox, condition_proj)
            sam2_logits = F.interpolate(sam2_logits, size=images.shape[-2:], mode='bilinear', align_corners=False)
            matting_logits = F.interpolate(matting_logits, size=images.shape[-2:], mode='bilinear', align_corners=False)
            sam_hq_matting_token = {
                'masks_hq': sam2_logits,
                'masks_matting': matting_logits,
            }
            hq_features = matting_logits
            low_res_masks = matting_logits
        else:
            if self.matting_token:
                features, image_pe, sparse_embeddings, dense_embeddings, interm_features, sam_hq_matting_token, pred_trimap = self.forward_samhq(images, bbox, condition_proj)
                if return_hq_token:
                    return sam_hq_matting_token['masks_hq']
                elif not self.training and self.test_w_hq_token:
                    low_res_masks, hq_features = sam_hq_matting_token['masks_hq'], sam_hq_matting_token['masks_hq']
                else:
                    low_res_masks, hq_features = sam_hq_matting_token['masks_matting'], sam_hq_matting_token['masks_matting']
            else:
                features, image_pe, sparse_embeddings, dense_embeddings, interm_features, hq_features, sam_logits, low_res_masks, pred_trimap = self.forward_samhq(images, bbox, condition_proj)
                if return_hq_token:
                    return hq_features
                sam_hq_matting_token = {'masks_hq': hq_features, 'masks_sam': sam_logits}

        if isinstance(self.matting_decoder, MattingDetailDecoder):
            pred_alpha = self.matting_decoder(
                images=images,
                hq_features=hq_features,
                vit_intern_feat=interm_features,
                return_alpha_logits=(self.alpha_reg_range is not None),
                pred_trimap=pred_trimap,
            )
        else:
            pred_alpha = self.matting_decoder(
                image_embeddings=features,
                image_pe=image_pe,
                sparse_prompt_embeddings=sparse_embeddings,
                dense_prompt_embeddings=dense_embeddings,
                multimask_output=False,
                interm_embeddings=interm_features,
                hq_features=hq_features,
                images=images,
                return_alpha_logits=(self.alpha_reg_range is not None),
                pred_trimap=pred_trimap,
            )
        return low_res_masks, pred_alpha, pred_trimap, sam_hq_matting_token
|

    def forward(self, batched_inputs):
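        """Inference entry point (training is asserted off below).

        Preprocesses `batched_inputs`, builds the backbone condition, and
        returns the predicted alpha matte (plus the frozen SAM-HQ mask when
        `eval_w_sam_hq_mask` is set).
        """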
|
        inputs = self.preprocess_inputs(batched_inputs)
        images, bbox, gt_alpha, trimap, condition = inputs['images'], inputs['bbox'], inputs['alpha'], inputs['trimap'], inputs['condition']

        if self.backbone_condition:
            condition_proj = self.condition_embedding(condition)
        elif self.backbone_bbox_prompt is not None or self.bbox_prompt_all_block is not None:
            condition_proj = bbox
        else:
            condition_proj = None

        low_res_masks, pred_alpha, pred_trimap, sam_hq_matting_token = self.forward_samhq_and_matting_decoder(images, bbox, condition_proj)

        assert not self.training
        if self.eval_w_sam_hq_mask:
            # Also report the masks of the frozen (adapter-free) SAM-HQ model.
            self.sam_predictor.model.image_encoder.disable_adapter_layers()
            with torch.no_grad():
                ori_features, ori_interm_features = self.sam_predictor.model.image_encoder(images)
                samhq_low_res_masks = self.forward_samhq_others(images, bbox, ori_features, ori_interm_features)[-1]
                samhq_low_res_masks = F.interpolate(samhq_low_res_masks, size=(images.shape[-2], images.shape[-1]), mode='bilinear', align_corners=False)
            self.sam_predictor.model.image_encoder.enable_adapter_layers()

            return pred_alpha, samhq_low_res_masks
        else:
            return pred_alpha
|

    def forward_samhq_image_encoder(self, images, condition_proj=None):
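        """Encode `images` with the (possibly LoRA-wrapped) image encoder.

        Returns `(features, interm_features, pred_trimap)`; for SAM2 the
        first element is a dict with `image_embed`/`high_res_feats` and the
        last two are None.
        """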
|
        if self.sam2:
            backbone_out = self.sam_predictor.model.forward_image([images, condition_proj])
            _, vision_feats, _, _ = self.sam_predictor.model._prepare_backbone_features(backbone_out)

            if self.sam_predictor.model.directly_add_no_mem_embed:
                vision_feats[-1] = vision_feats[-1] + self.sam_predictor.model.no_mem_embed
            feats = [
                feat.permute(1, 2, 0).view(feat.shape[1], -1, *feat_size)
                for feat, feat_size in zip(vision_feats[::-1], self.sam_predictor._bb_feat_sizes[::-1])
            ][::-1]
            return {"image_embed": feats[-1], "high_res_feats": feats[:-1]}, None, None
        else:
            if self.backbone_condition:
                condition_layer = self.condition_conv
            elif self.backbone_bbox_prompt:
                condition_layer = self.condition_layer
            else:
                condition_layer = None

            features, interm_features, pred_trimap = self.sam_predictor.model.image_encoder(images, condition_proj, condition_layer)
            return features, interm_features, pred_trimap
|

    def forward_samhq_others(self, images, bbox, features, interm_features):
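        """Run the prompt encoder and mask decoder per image in the batch.

        Boxes are encoded one image at a time (each image may carry several
        box prompts), and the per-image outputs are concatenated back into
        batch tensors, with mask logits upsampled to the input resolution.
        """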
|
        if self.sam2:
            sam2_logits, matting_logits = self.sam_predictor.predict_batch_boxes_and_features(bbox, features)
            return features, sam2_logits, matting_logits

        image_pe = self.sam_predictor.model.prompt_encoder.get_dense_pe()

        cat_sparse_embeddings = []
        cat_dense_prompt_embeddings = []
        cat_hq_features = []
        cat_sam_logits = []
        cat_low_res_masks = []
        cat_sam_hq_matting_token = defaultdict(list)

        for idx in range(images.shape[0]):
            sparse_embeddings, dense_embeddings = self.sam_predictor.model.prompt_encoder(
                points=None,
                boxes=bbox[idx],
                masks=None,
            )

            if isinstance(self.sam_predictor.model.mask_decoder, MaskDecoderHQMatting):
                sam_hq_matting_token = self.sam_predictor.model.mask_decoder(
                    image_embeddings=features[idx: idx + 1],
                    image_pe=image_pe,
                    sparse_prompt_embeddings=sparse_embeddings,
                    dense_prompt_embeddings=dense_embeddings,
                    multimask_output=self.multimask_output,
                    interm_embeddings=[interm_feature[idx: idx + 1] for interm_feature in interm_features],
                )
                for key in sam_hq_matting_token.keys():
                    cat_sam_hq_matting_token[key].append(sam_hq_matting_token[key])
            else:
                low_res_masks, masks_sam, hq_features = self.sam_predictor.model.mask_decoder(
                    image_embeddings=features[idx: idx + 1],
                    image_pe=image_pe,
                    sparse_prompt_embeddings=sparse_embeddings,
                    dense_prompt_embeddings=dense_embeddings,
                    multimask_output=self.multimask_output,
                    hq_token_only=self.hq_token_only,
                    interm_embeddings=[interm_feature[idx: idx + 1] for interm_feature in interm_features],
                    return_hq_features_type=self.hq_features_type,
                )
                cat_hq_features.append(hq_features)
                cat_sam_logits.append(masks_sam)
                cat_low_res_masks.append(low_res_masks)

            cat_sparse_embeddings.append(sparse_embeddings)
            cat_dense_prompt_embeddings.append(dense_embeddings)

        sparse_embeddings = torch.stack(cat_sparse_embeddings, dim=0)
        dense_embeddings = torch.stack(cat_dense_prompt_embeddings, dim=0)

        if self.matting_token:
            for key in cat_sam_hq_matting_token.keys():
                cat_sam_hq_matting_token[key] = torch.cat(cat_sam_hq_matting_token[key], dim=0)
                cat_sam_hq_matting_token[key] = F.interpolate(cat_sam_hq_matting_token[key], size=images.shape[-2:], mode='bilinear', align_corners=False)
            sam_hq_matting_token = cat_sam_hq_matting_token
            return features, image_pe, sparse_embeddings, dense_embeddings, interm_features, sam_hq_matting_token
        else:
            hq_features = torch.cat(cat_hq_features, dim=0)
            low_res_masks = torch.cat(cat_low_res_masks, dim=0)
            hq_features = F.interpolate(hq_features, size=images.shape[-2:], mode='bilinear', align_corners=False)
            sam_logits = torch.cat(cat_sam_logits, dim=0)
            sam_logits = F.interpolate(sam_logits, size=images.shape[-2:], mode='bilinear', align_corners=False)
            return features, image_pe, sparse_embeddings, dense_embeddings, interm_features, hq_features, sam_logits, low_res_masks
|

    def forward_samhq(self, images, bbox, condition_proj=None):
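        """Image encoder (frozen unless LoRA is enabled) plus SAM-HQ heads."""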
|
        if self.lora_rank is None:
            with torch.no_grad():
                features, interm_features, pred_trimap = self.forward_samhq_image_encoder(images, condition_proj)
        else:
            features, interm_features, pred_trimap = self.forward_samhq_image_encoder(images, condition_proj)

        return self.forward_samhq_others(images, bbox, features, interm_features) + (pred_trimap,)
|

    def get_frozen_sam_logits(self, images, bbox, mask_type='hq'):
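        """Mask logits from the frozen SAM-HQ decoder, for regularization.

        `mask_type` selects the HQ-token features ('hq') or the plain SAM
        logits ('sam'); outputs are upsampled to the input resolution.
        """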
|
        if self.sam2:
            features, _, _ = self.forward_samhq_image_encoder(images)
            sam2_logits = self.sam_predictor.predict_batch_boxes_and_features(bbox, features, wo_matting_token=True)
            sam2_logits = F.interpolate(sam2_logits, size=images.shape[-2:], mode='bilinear', align_corners=False)
            return sam2_logits

        assert mask_type in {'hq', 'sam'}
        features, interm_features, _ = self.forward_samhq_image_encoder(images)
        image_pe = self.sam_predictor.model.prompt_encoder.get_dense_pe()

        cat_logits = []
        for idx in range(images.shape[0]):
            sparse_embeddings, dense_embeddings = self.sam_predictor.model.prompt_encoder(points=None, boxes=bbox[idx], masks=None)

            low_res_masks, masks_sam, hq_features = self.sam_predictor.model.frozen_mask_decoder(
                image_embeddings=features[idx: idx + 1],
                image_pe=image_pe,
                sparse_prompt_embeddings=sparse_embeddings,
                dense_prompt_embeddings=dense_embeddings,
                multimask_output=self.multimask_output,
                hq_token_only=self.hq_token_only,
                interm_embeddings=[interm_feature[idx: idx + 1] for interm_feature in interm_features],
                return_hq_features_type=self.hq_features_type,
            )
            if mask_type == 'hq':
                cat_logits.append(hq_features)
            else:
                cat_logits.append(masks_sam)

        logits = torch.cat(cat_logits, dim=0)
        logits = F.interpolate(logits, size=images.shape[-2:], mode='bilinear', align_corners=False)
        return logits
|

    def vis_training_results(self, **kwargs):
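        """Periodically dump a grid of training visualizations.

        Every `vis_period` calls, draws the first box prompt on the
        de-normalized image, binarizes or rescales mask-like tensors, and
        saves all entries side by side under `self.output_dir`.
        """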
|
        self.train_iter_index += 1
        if self.train_iter_index % self.vis_period == 0:
            batch_save_results = []
            save_path = os.path.join(self.output_dir, '{:06d}_rank{}.jpg'.format(self.train_iter_index, get_local_rank()))

            for key in kwargs.keys():
                if key == 'bbox':
                    continue

                if key == 'images':
                    # De-normalize and draw the first box prompt in red.
                    kwargs[key] = kwargs[key] * self.pixel_std + self.pixel_mean
                    kwargs[key] = kwargs[key].permute(0, 2, 3, 1) * 255.0
                    for i in range(kwargs['images'].shape[0]):
                        l, u, r, d = int(kwargs['bbox'][i, 0, 0].item()), int(kwargs['bbox'][i, 0, 1].item()), int(kwargs['bbox'][i, 0, 2].item()), int(kwargs['bbox'][i, 0, 3].item())
                        red_line = torch.tensor([[255., 0., 0.]], device=kwargs[key].device, dtype=kwargs[key].dtype)
                        kwargs[key][i, u: d, l, :] = red_line
                        kwargs[key][i, u: d, r, :] = red_line
                        kwargs[key][i, u, l: r, :] = red_line
                        kwargs[key][i, d, l: r, :] = red_line
                elif key in {'low_res_masks', 'frozen_hq_token'}:
                    if torch.max(kwargs[key]) <= 1:
                        kwargs[key] = kwargs[key].permute(0, 2, 3, 1).repeat(1, 1, 1, 3) * 255.0
                    else:
                        # Raw logits: upsample and threshold before visualizing.
                        kwargs[key] = F.interpolate(kwargs[key], size=(kwargs['images'].shape[-3], kwargs['images'].shape[-2]), mode='bilinear', align_corners=False)
                        kwargs[key] = (kwargs[key] > self.sam_predictor.model.mask_threshold).float().permute(0, 2, 3, 1).repeat(1, 1, 1, 3) * 255.0
                else:
                    kwargs[key] = kwargs[key].permute(0, 2, 3, 1).repeat(1, 1, 1, 3) * 255.0

                kwargs[key] = np.uint8(kwargs[key].detach().cpu().numpy())

            # One row per sample, one column per visualized tensor.
            for i in range(kwargs['images'].shape[0]):
                save_results = []
                for key in kwargs.keys():
                    if key != 'bbox':
                        save_results.append(kwargs[key][i])
                batch_save_results.append(np.concatenate(save_results, axis=1))

            Image.fromarray(np.concatenate(batch_save_results, axis=0)).save(save_path)
|

    def preprocess_inputs(self, batched_inputs):
        """
        Normalize and batch the input images; pad high-resolution images to a
        multiple of 16 when present.
        """
        output = dict()
|

        if "alpha" in batched_inputs:
            alpha = batched_inputs["alpha"].to(self.device)
        else:
            alpha = None

        bbox = batched_inputs["bbox"].to(self.device)

        if self.training and self.coconut_self_training and sum([i == 'COCONut' for i in batched_inputs['dataset_name']]) >= 1:
            # Self-training on COCONut samples: keep the original image as the
            # clean view, replace the input with a strongly augmented and
            # Gaussian-blurred copy, and derive a trimap from the alpha mask.
            output['coconut_ori_img'] = []
            output['coconut_trimap'] = []
            output['coconut_bbox'] = []
            output['coconut_idx'] = []
            for i, dataset_name in enumerate(batched_inputs['dataset_name']):
                if dataset_name == 'COCONut':
                    img_np = np.uint8(batched_inputs["image"][i].permute(1, 2, 0).cpu().numpy() * 255.)
                    strong_aug_img = self.rand_aug(Image.fromarray(img_np), cutout=False)
                    strong_aug_img_tensor = torch.from_numpy(np.array(strong_aug_img)).to(self.device).permute(2, 0, 1)[None] / 255.
                    blur_kernel_sigma = 1.0 + random.random()
                    blur_filter = kf.GaussianBlur2d((101, 101), (blur_kernel_sigma, blur_kernel_sigma))
                    blur_strong_aug_img_tensor = blur_filter(strong_aug_img_tensor)[0]

                    output['coconut_ori_img'].append(batched_inputs["image"][i])
                    batched_inputs["image"][i] = blur_strong_aug_img_tensor

                    # The trimap-generation kernel scales with the mask area.
                    coconut_mask = (alpha[i] != 0).float()
                    mask_area = torch.sum(coconut_mask)
                    kernel_size = max(self.matting_decoder.min_kernel_size, int((mask_area ** 0.5) / 7))
                    kernel_size = min(kernel_size, self.matting_decoder.gen_trimap.max_kernal - 1)
                    output['coconut_trimap'].append(self.matting_decoder.gen_trimap(coconut_mask[0], kernel_size=kernel_size)[None])

                    output['coconut_bbox'].append(bbox[i])
                    output['coconut_idx'].append(i)

            output['coconut_ori_img'] = torch.stack(output['coconut_ori_img']).to(self.device)
            output['coconut_ori_img'] = (output['coconut_ori_img'] - self.pixel_mean) / self.pixel_std
            output['coconut_trimap'] = torch.stack(output['coconut_trimap']).to(self.device)
            output['coconut_bbox'] = torch.stack(output['coconut_bbox']).to(self.device)

        images = batched_inputs["image"].to(self.device)
        images = (images - self.pixel_mean) / self.pixel_std
        assert images.shape[-2] == images.shape[-1] == 1024

        if 'trimap' in batched_inputs.keys():
            trimap = batched_inputs["trimap"].to(self.device)
            assert len(torch.unique(trimap)) <= 3
        else:
            trimap = None

        output['images'] = images
        output['bbox'] = bbox
        output['alpha'] = alpha
        output['trimap'] = trimap

        if 'hr_images' in batched_inputs.keys():
            hr_images = batched_inputs["hr_images"].to(self.device)
            hr_images = (hr_images - self.pixel_mean) / self.pixel_std
            _, _, H, W = hr_images.shape
            if hr_images.shape[-1] % 16 != 0 or hr_images.shape[-2] % 16 != 0:
                # Zero-pad bottom/right so both sides are multiples of 16.
                new_H = (16 - hr_images.shape[-2] % 16) + H if hr_images.shape[-2] % 16 != 0 else H
                new_W = (16 - hr_images.shape[-1] % 16) + W if hr_images.shape[-1] % 16 != 0 else W
                new_hr_images = torch.zeros((hr_images.shape[0], hr_images.shape[1], new_H, new_W)).to(self.device)
                new_hr_images[:, :, :H, :W] = hr_images[:, :, :, :]
                del hr_images
                hr_images = new_hr_images
            output['hr_images'] = hr_images
            output['hr_images_ori_h_w'] = (H, W)

        if 'dataset_name' in batched_inputs.keys():
            output['dataset_name'] = batched_inputs["dataset_name"]

        if self.backbone_condition:
            if self.w_only_bbox_cond:
                output['condition'] = output['bbox'][:, 0, :]
            else:
                multi_fg_float = batched_inputs["multi_fg"].to(bbox.device).float()[:, None] * 512
                output['condition'] = torch.concat((output['bbox'][:, 0, :], multi_fg_float), dim=-1)
        else:
            output['condition'] = None

        return output
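
# Minimal usage sketch (illustrative only, not part of the training pipeline).
# `sam_model` and `matting_decoder` stand in for whatever this repo's configs
# construct; `hq_features_type` and the ImageNet pixel statistics below are
# assumptions, not values confirmed by this file.
#
#   model = SamHqMatte(
#       sam_model=sam_model,                # a SAM/SAM-HQ (or SAM2) backbone
#       hq_token_only=True,
#       hq_features_type='Final',           # assumed value
#       matting_decoder=matting_decoder,    # e.g. a MattingDetailDecoder
#       criterion=None,                     # unused at inference
#       pixel_mean=[123.675, 116.28, 103.53],
#       pixel_std=[58.395, 57.12, 57.375],
#   ).eval()
#   model.init_lora()                       # no-op unless lora_rank etc. are set
#   # `images_1024`: (B, 3, 1024, 1024) float tensor; `boxes`: (B, N, 4) xyxy.
#   pred_alpha = model({'image': images_1024, 'bbox': boxes})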