Spaces:
Sleeping
Sleeping
File size: 2,247 Bytes
7a9cb7c 0053d00 c08cd12 7a9cb7c 94e002e 1681aee 7a9cb7c 1681aee 7a9cb7c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 |
import torch
import torchvision
import gradio as gr
import pathlib
import random
from torch import nn
from typing import Tuple, Dict
from PIL import Image
from timeit import default_timer as timer
from typing import Tuple, Dict
# Run on the GPU when one is visible to PyTorch, otherwise on the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Class labels, one per line; index i names output unit i of the classifier.
with open('class-names.txt', 'r', encoding='utf-8') as f:
    # splitlines() drops the trailing newline without also discarding the last
    # label when the file does not end in '\n' (which split('\n')[:-1] would do,
    # shrinking the classifier head), and it strips '\r' from CRLF files.
    class_names = f.read().splitlines()
def load_model() -> Tuple[torch.nn.Module, torch.nn.Module]:
    """Build the ShuffleNetV2 x1.5 classifier and its preprocessing pipeline.

    The architecture's final ``fc`` layer is replaced by one with
    ``len(class_names)`` outputs, then every weight is restored from the
    fine-tuned checkpoint ``ShuffleNetV2.pt``.

    Returns:
        (model, transforms): the classifier with the checkpoint loaded (its
        parameters stay on the CPU; callers move it to ``device``), and the
        preprocessing preset that accompanies the ImageNet ShuffleNetV2 x1.5
        weights (resize / crop / normalize).
    """
    weights = torchvision.models.ShuffleNet_V2_X1_5_Weights.IMAGENET1K_V1
    # Only the preprocessing recipe is taken from the weights enum.
    shufflenet_transforms = weights.transforms()
    # weights=None: the strict load_state_dict below overwrites every parameter
    # and buffer, so downloading the ImageNet checkpoint first is wasted work.
    shufflenet = torchvision.models.shufflenet_v2_x1_5(weights=None)
    # Size the new head from the backbone rather than hard-coding 1024.
    shufflenet.fc = nn.Linear(
        in_features=shufflenet.fc.in_features,
        out_features=len(class_names),
        bias=True,
    )
    # weights_only=True restricts unpickling to tensors and plain containers,
    # so a tampered checkpoint file cannot execute arbitrary code on load.
    state_dict = torch.load('ShuffleNetV2.pt', map_location=device, weights_only=True)
    shufflenet.load_state_dict(state_dict)
    return shufflenet, shufflenet_transforms


# Loaded once at import time and shared by every prediction request.
model, transforms = load_model()
def predict(img) -> Tuple[Dict, float]:
    """Classify one image and report how long the call took.

    Args:
        img: a PIL image (the Gradio input is declared with ``type='pil'``).
            ``None`` is rejected; Gradio may pass it when nothing was uploaded.

    Returns:
        (pred_dict, pred_time): a mapping from every class name to its softmax
        probability, and the elapsed wall-clock time in seconds, rounded to
        5 decimal places.

    Raises:
        gr.Error: if ``img`` is ``None``, so the UI shows a readable message.
    """
    if img is None:
        raise gr.Error('Please upload an image to classify.')
    start = timer()
    # Both calls are no-ops after the first request; kept so the model is on
    # `device` and in eval mode no matter how it was loaded.
    model.to(device)
    model.eval()
    with torch.inference_mode():
        # Defensive: the ImageNet backbone expects 3 channels, and RGBA or
        # greyscale images would otherwise fail inside the first convolution.
        transformed_img = transforms(img.convert('RGB')).to(device)
        logits = model(transformed_img.unsqueeze(dim=0))
        pred_prob = torch.softmax(logits, dim=1)
        # One device->host copy of the whole probability vector instead of
        # one .item() call (and sync) per class.
        probs = pred_prob.squeeze(0).tolist()
    pred_dict = dict(zip(class_names, probs))
    pred_time = round(timer() - start, 5)
    return pred_dict, pred_time
# Example images live one directory level below 'examples/'; a random subset of
# up to six is shown under the input on each start.
example_paths = list(pathlib.Path('examples').glob("*/*.jpg"))
# random.sample raises ValueError when k exceeds the population size, so cap k
# at the number of images actually found.
num_examples = min(6, len(example_paths))
example_list = [[str(filepath)] for filepath in random.sample(example_paths, k=num_examples)]

title = 'Bird Species Classifier 🐦'
description = 'A [ShuffleNetV2](https://pytorch.org/vision/main/models/shufflenetv2.html) feature extractor computer vision model to classify images of [525 bird species](https://www.kaggle.com/datasets/gpiosenka/100-bird-species/).'
article = 'Made with ❤️🤗 by [me](https://www.linkedin.com/in/taufiq-dwi-purnomo/).'

# Single image in; top-3 labels and the measured inference time out.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type='pil'),
    outputs=[gr.Label(num_top_classes=3, label='Predictions'),
             gr.Number(label="Prediction time (s)")],
    description=description,
    title=title,
    allow_flagging='never',
    # Pass None (no examples panel) instead of an empty list when none exist.
    examples=example_list or None,
    article=article,
)
demo.launch(debug=False)
|