|
import torch |
|
import transformers |
|
from torch import nn |
|
from transformers.modeling_outputs import SemanticSegmenterOutput |
|
|
|
|
|
def encode_down(c_in: int, c_out: int) -> nn.Sequential:
    """Build a double 3x3 convolution block.

    The block is ``Conv2d -> BatchNorm2d -> ReLU`` applied twice. The first
    convolution maps ``c_in`` channels to ``c_out``; the second keeps
    ``c_out`` channels. Both convolutions use ``kernel_size=3`` with
    ``padding=1`` and the default stride, so height and width are unchanged.

    Args:
        c_in: Number of channels of the incoming feature map.
        c_out: Number of channels produced by the block.

    Returns:
        An ``nn.Sequential`` with six sub-modules (indices 0-5) in the order
        conv, batch-norm, ReLU, conv, batch-norm, ReLU.
    """
    return nn.Sequential(
        # Stage 1: change the channel count.
        nn.Conv2d(in_channels=c_in, out_channels=c_out, kernel_size=3, padding=1),
        nn.BatchNorm2d(num_features=c_out),
        nn.ReLU(inplace=True),
        # Stage 2: refine at the new channel count.
        nn.Conv2d(in_channels=c_out, out_channels=c_out, kernel_size=3, padding=1),
        nn.BatchNorm2d(num_features=c_out),
        nn.ReLU(inplace=True),
    )
|
|
|
|
|
def decode_up(c: int) -> nn.ConvTranspose2d:
    """Build the 2x learned upsampling layer used between decoder stages.

    The layer is a transposed convolution with ``kernel_size=2`` and
    ``stride=2``: it doubles height and width and halves the channel count.

    Args:
        c: Number of input channels. The output has ``c // 2`` channels.

    Returns:
        The configured ``nn.ConvTranspose2d`` module.
    """
    return nn.ConvTranspose2d(
        in_channels=c,
        # Integer floor division keeps the value an exact int instead of
        # round-tripping through float as ``int(c / 2)`` does.
        out_channels=c // 2,
        kernel_size=2,
        stride=2,
    )
|
|
|
|
|
class FaceUNet(nn.Module):
    """U-Net style encoder/decoder that outputs per-pixel class logits.

    Encoder: five stages producing 64, 128, 256, 512 and 1024 channels, with
    a 2x2 max-pool between consecutive stages. The first stage is a single
    3x3 convolution on a 3-channel input; the others are ``encode_down``
    blocks.

    Decoder: four steps. Each step upsamples 2x with ``decode_up``,
    concatenates the encoder feature map of the same level on the channel
    axis (skip first, upsampled second) and fuses the result with an
    ``encode_down`` block.

    Head: a 3x3 convolution mapping 64 channels to ``num_classes``.

    Args:
        num_classes: Number of output channels of the segmentation head.
    """

    def __init__(self, num_classes: int):
        super().__init__()
        self.num_classes = num_classes

        # Encoder stages (channel count doubles at every level).
        self.down_1 = nn.Conv2d(
            in_channels=3,
            out_channels=64,
            kernel_size=3,
            padding=1,
        )
        self.down_2 = encode_down(64, 128)
        self.down_3 = encode_down(128, 256)
        self.down_4 = encode_down(256, 512)
        self.down_5 = encode_down(512, 1024)

        # One stateless pooling layer shared by all encoder levels.
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Decoder: ``up_N`` halves the channels while doubling resolution;
        # ``up_cN`` fuses the concatenation [skip, upsampled], which has the
        # same channel count as the input of ``up_N``.
        self.up_1 = decode_up(1024)
        self.up_c1 = encode_down(1024, 512)
        self.up_2 = decode_up(512)
        self.up_c2 = encode_down(512, 256)
        self.up_3 = decode_up(256)
        self.up_c3 = encode_down(256, 128)
        self.up_4 = decode_up(128)
        self.up_c4 = encode_down(128, 64)

        # Segmentation head.
        self.segment = nn.Conv2d(
            in_channels=64,
            out_channels=self.num_classes,
            kernel_size=3,
            padding=1,
        )

    @staticmethod
    def _concat_skip(skip: torch.Tensor, upsampled: torch.Tensor) -> torch.Tensor:
        """Concatenate a skip connection with an upsampled decoder map.

        Max-pooling floors odd sizes, so after the 2x upsampling the decoder
        map can be one pixel smaller than the skip map. In that case the
        decoder map is zero-padded to the skip size so ``torch.cat`` does not
        fail. When the sizes already agree the tensors are concatenated
        unchanged.
        """
        pad_h = skip.size(-2) - upsampled.size(-2)
        pad_w = skip.size(-1) - upsampled.size(-1)
        if pad_h or pad_w:
            # F.pad order for the two trailing dims: (left, right, top, bottom).
            upsampled = nn.functional.pad(
                upsampled,
                [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2],
            )
        return torch.cat([skip, upsampled], 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute segmentation logits for a batch of images.

        Args:
            x: Tensor with 3 channels on dim 1 and height/width on the last
                two dims. Each spatial size must be at least 16 so that it
                survives the four 2x poolings.

        Returns:
            Tensor with ``num_classes`` channels on dim 1 and the same
            height and width as ``x``.
        """
        # Encoder: keep every pre-pooling feature map for the skip paths.
        enc_1 = self.down_1(x)
        enc_2 = self.down_2(self.pool(enc_1))
        enc_3 = self.down_3(self.pool(enc_2))
        enc_4 = self.down_4(self.pool(enc_3))
        bottleneck = self.down_5(self.pool(enc_4))

        # Decoder: upsample, merge with the matching encoder level, fuse.
        x = self.up_c1(self._concat_skip(enc_4, self.up_1(bottleneck)))
        x = self.up_c2(self._concat_skip(enc_3, self.up_2(x)))
        x = self.up_c3(self._concat_skip(enc_2, self.up_3(x)))
        x = self.up_c4(self._concat_skip(enc_1, self.up_4(x)))

        return self.segment(x)
|
|
|
|
|
class Segformer(transformers.PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around ``FaceUNet``.

    Despite the name, the wrapped network is the ``FaceUNet`` defined in this
    module; ``SegformerConfig`` is only used as the configuration container.
    The config must expose a ``num_classes`` attribute, which sets the number
    of output channels.
    """

    config_class = transformers.SegformerConfig

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.model = FaceUNet(num_classes=config.num_classes)

    def forward(self, tensor):
        """Run the wrapped U-Net and return its raw output tensor.

        ``FaceUNet`` defines no ``forward_features`` method, so the module is
        invoked through its regular forward pass.
        """
        return self.model(tensor)
|
|
|
|
|
class SegformerForSemanticSegmentation(transformers.PreTrainedModel):
    """Semantic-segmentation model backed by ``FaceUNet``.

    ``SegformerConfig`` is used as the configuration container; the config
    must expose a ``num_classes`` attribute, which sets the number of logit
    channels.
    """

    config_class = transformers.SegformerConfig

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.model = FaceUNet(num_classes=config.num_classes)

    def forward(self, pixel_values, labels=None):
        """Compute per-pixel logits and, when labels are given, the loss.

        Args:
            pixel_values: Image batch passed straight to ``FaceUNet``.
            labels: Optional target passed with the logits to
                ``torch.nn.functional.cross_entropy``.
                # NOTE(review): presumably a (batch, H, W) tensor of integer
                # class indices -- confirm against the data pipeline.

        Returns:
            ``SemanticSegmenterOutput`` with ``logits`` always set and
            ``loss`` set only when ``labels`` is provided.
        """
        logits = self.model(pixel_values)
        values = {"logits": logits}
        if labels is not None:
            # The functional cross-entropy lives in ``torch.nn.functional``;
            # ``torch.nn`` has no ``cross_entropy`` attribute.
            values["loss"] = torch.nn.functional.cross_entropy(logits, labels)
        return SemanticSegmenterOutput(**values)
|
|