import torch.nn as nn
from torchvision.models import resnet18, ResNet18_Weights
import torch.nn.functional as F
import torch


class Model(nn.Module):
    """ResNet-18 backbone with two heads: object presence and bounding box.

    The model output has 1 + 4 values per sample:

    * 1 value: probability that the input contains an object of any class.
    * 4 values: bounding box of the presented object. When no object is
      presented (i.e. the probability is below a caller-chosen threshold)
      these four numbers are arbitrary and should be ignored.

    The bounding-box coordinate convention is not fixed by this module; it is
    determined by the training targets.
    """

    def __init__(self):
        super().__init__()

        # `weights` must be an enum *member* (or None / a string), not the
        # enum class itself; passing the class is rejected by torchvision's
        # weight verification. DEFAULT selects the recommended pretrained
        # weights for ResNet-18.
        self.feature_extractor = resnet18(weights=ResNet18_Weights.DEFAULT)

        # Remember the width of the pooled feature vector, then drop the
        # original classification layer so the backbone returns raw features.
        feature_dim = self.feature_extractor.fc.in_features
        self.feature_extractor.fc = nn.Identity()

        # NOTE(review): both heads stack two Linear layers with no activation
        # in between, which is mathematically a single affine map. Left as-is
        # so existing state_dict keys keep loading; confirm whether a
        # non-linearity was intended.
        self.fc_prob = nn.Sequential(
            nn.Linear(feature_dim, 512),
            nn.Linear(512, 1),
        )
        self.fc_bbox = nn.Sequential(
            nn.Linear(feature_dim, 512),
            nn.Linear(512, 4),
        )

    def forward(self, x):
        """Run the backbone once and apply both heads to its features.

        Args:
            x: Input batch passed straight to the ResNet-18 backbone
                (presumably image tensors of shape (batch, 3, H, W) --
                confirm against the data pipeline).

        Returns:
            A tuple ``(pred_prob, pred_bbox)`` where ``pred_prob`` is the
            sigmoid of the presence head (last dimension 1, values in (0, 1))
            and ``pred_bbox`` is the raw output of the box head (last
            dimension 4).
        """
        # Extract features a single time: a second backbone pass would double
        # the compute and, in training mode, update the BatchNorm running
        # statistics twice per step.
        features = self.feature_extractor(x)

        pred_prob = torch.sigmoid(self.fc_prob(features))
        pred_bbox = self.fc_bbox(features)
        return (pred_prob, pred_bbox)