# Spaces: Runtime error
# (Status banner captured together with this file; it is not part of the code.)
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import ResNet18_Weights, resnet18
class Model(nn.Module):
    """Single-object model: a ResNet-18 backbone feeding two prediction heads.

    The backbone's final classification layer is replaced by ``nn.Identity`` so
    it returns its feature vector, which is fed to:

    * ``fc_prob`` -- one output, passed through a sigmoid in ``forward``, giving
      the probability that an object (of any class) is present;
    * ``fc_bbox`` -- four raw numbers for that object's bounding box. When the
      probability is below the caller's threshold these values are meaningless.
      NOTE(review): the box convention (xyxy vs. xywh, normalized vs. pixels)
      is not visible here -- TODO confirm against the training code.
    """

    def __init__(self):
        super().__init__()
        # Pass an enum *member*: torchvision rejects the bare ResNet18_Weights
        # class with a TypeError. DEFAULT selects the recommended
        # ImageNet-pretrained weights (downloaded on first use).
        self.feature_extractor = resnet18(weights=ResNet18_Weights.DEFAULT)
        in_features = self.feature_extractor.fc.in_features
        # Drop the ImageNet classifier so the backbone outputs its feature
        # vector of size `in_features` instead of class logits.
        self.feature_extractor.fc = nn.Identity()

        # NOTE(review): each head stacks two nn.Linear layers with no
        # non-linearity in between, which is mathematically equivalent to a
        # single linear layer. Inserting e.g. nn.ReLU() would add capacity, but
        # it also shifts the Sequential indices and breaks loading existing
        # checkpoints, so the layout is intentionally left unchanged here.
        self.fc_prob = nn.Sequential(
            nn.Linear(in_features, 512),
            nn.Linear(512, 1),
        )
        self.fc_bbox = nn.Sequential(
            nn.Linear(in_features, 512),
            nn.Linear(512, 4),
        )

    def forward(self, x):
        """Run the backbone once and apply both heads.

        Args:
            x: Input image batch; presumably ``(batch, 3, H, W)`` as expected by
                torchvision's ResNet -- TODO confirm against the caller.

        Returns:
            Tuple ``(pred_prob, pred_bbox)``: ``pred_prob`` has last dimension 1
            with values in [0, 1] (sigmoid output); ``pred_bbox`` has last
            dimension 4 and is unconstrained.
        """
        # Compute the backbone features once and share them between the heads.
        # Calling the backbone per head would double the compute and, in train
        # mode, update its BatchNorm running statistics twice per step.
        features = self.feature_extractor(x)
        pred_prob = torch.sigmoid(self.fc_prob(features))
        pred_bbox = self.fc_bbox(features)
        return pred_prob, pred_bbox