Edit model card

This model is not built on the Hugging Face `transformers` library, so the checkpoint file has to be downloaded manually:

wget https://huggingface.co/Lancelot53/icon_classifier_maxvit/resolve/main/best_model_89.pth

Inference Code

import torch
import torch.nn as nn
from torchvision import transforms, models
from PIL import Image
import torch.nn.functional as F

#load id_2_class.json
import json

# Model output index (as a string key) -> icon label. Keys "0".."42" give
# 43 classes; the string keys presumably mirror an id_2_class.json file.
id_2_class = {"0": "back", "1": "Briefcase", "2": "Call", "3": "Camera", "4": "Circle", "5": "Cloud", "6": "delete", "7": "Down", "8": "edit", "9": "Export", "10": "Face", "11": "Folder", "12": "Globe", "13": "Google", "14": "Heart", "15": "Home", "16": "Image", "17": "Import", "18": "Info", "19": "Link", "20": "Location", "21": "Mail", "22": "menu", "23": "Merge", "24": "Message", "25": "Microphone", "26": "more", "27": "Music", "28": "Mute", "29": "Person", "30": "Phone", "31": "plus", "32": "QRCODE", "33": "Refresh", "34": "search", "35": "settings", "36": "share", "37": "Star", "38": "Tick", "39": "Up", "40": "vidCam", "41": "Video", "42": "Volume"}

# Reverse lookup: icon label -> output index (string). len(class_2_id) is
# used as the number of classes for the classifier head.
class_2_id = {label: idx for idx, label in id_2_class.items()}

# Inference-time preprocessing:
#   1. resize the PIL image to a fixed 224x224,
#   2. convert it to a float tensor (ToTensor scales pixel values to [0, 1]),
#   3. normalize each of the 3 channels with mean 0.5 / std 0.5, which maps
#      [0, 1] to [-1, 1].
# Presumably this matches the training-time transform; confirm in the
# training repo.
test_transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ]
)

class MaxViT(nn.Module):
    """Wrap torchvision's MaxViT-T with a resized final classifier layer.

    The last layer of ``maxvit_t``'s classifier (index 5) is replaced by a
    ``nn.Linear`` with one output per icon class. The wrapped network is
    stored as ``self.model``, so state_dict keys carry a ``model.`` prefix.

    Args:
        weights: Forwarded to ``torchvision.models.maxvit_t``. The default
            ``"DEFAULT"`` keeps the original behaviour (pretrained weights are
            fetched). Pass ``None`` when a fine-tuned checkpoint will be
            loaded right afterwards; this skips a redundant download.
        num_classes: Output size of the new head. ``None`` (default) means
            ``len(class_2_id)``, looked up when the instance is created.
    """

    def __init__(self, weights="DEFAULT", num_classes=None):
        super().__init__()
        if num_classes is None:
            num_classes = len(class_2_id)
        backbone = models.maxvit_t(weights=weights)
        head_in = backbone.classifier[5].in_features
        backbone.classifier[5] = nn.Linear(head_in, num_classes)
        self.model = backbone

    def forward(self, x):
        """Return raw class scores (pre-softmax) from the wrapped network.

        ``x`` is presumably a batch shaped (N, 3, 224, 224), as produced by
        ``test_transform`` plus ``unsqueeze(0)``. Confirm for other callers.
        """
        return self.model(x)

# Instantiate the model and load the fine-tuned checkpoint (best_model_89.pth).
model = MaxViT()
# map_location="cpu": a checkpoint saved from a GPU would otherwise fail to
# load on machines without CUDA.
# weights_only=True: the file is downloaded from the internet, so refuse to
# unpickle arbitrary objects. It is loaded straight into load_state_dict, so
# a plain state_dict of tensors is expected.
model.load_state_dict(
    torch.load("best_model_89.pth", map_location="cpu", weights_only=True)
)
# Put layers such as batch-norm/dropout into evaluation behaviour.
model.eval()

def inference(image_path, CONFIDENT_THRESHOLD=None):
    """Classify a single icon image with the module-level ``model``.

    Args:
        image_path: Anything accepted by ``PIL.Image.open`` (path or file
            object).
        CONFIDENT_THRESHOLD: Optional minimum softmax probability. If given
            and the top probability is below it, ``"UNKNOWN_CLASS"`` is
            returned instead of the predicted label.

    Returns:
        A ``(label, confidence)`` tuple. ``label`` is a value from
        ``id_2_class`` or ``"UNKNOWN_CLASS"``. ``confidence`` is the top
        softmax probability as a Python float.
    """
    # Use a context manager so the file handle is closed once the pixels are
    # decoded; convert() returns an independent, fully loaded image.
    # Going through "L" (grayscale) and back to RGB discards colour. This
    # presumably mirrors the training preprocessing; confirm in the repo.
    with Image.open(image_path) as raw:
        img = raw.convert("L").convert("RGB")
    batch = test_transform(img).unsqueeze(0)  # add a batch dimension of 1

    with torch.no_grad():
        probs = F.softmax(model(batch), dim=1)
        top_prob, top_idx = torch.max(probs, dim=1)

    confidence = top_prob.item()
    if CONFIDENT_THRESHOLD is not None and confidence < CONFIDENT_THRESHOLD:
        return "UNKNOWN_CLASS", confidence

    return id_2_class[str(top_idx.item())], confidence

# Example: classify one image and report predictions below 0.9 confidence as
# "UNKNOWN_CLASS" (0.9 is the author's suggested threshold). The result is
# printed because a bare expression shows nothing when run as a script.
print(inference("images/7820.jpg", 0.9))

Training

See the repository for the training code.

Dataset

Trained on Lancelot53/android_icon_dataset

Downloads last month

-

Downloads are not tracked for this model. How to track
Inference API
Unable to determine this model's library. Check the docs .