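"""Gradio app that classifies handwritten MNIST digits with a pretrained multilayer perceptron."""
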
import gradio as gr
import torch
from model import *  # the network class must be importable so torch.load can unpickle the saved model

from PIL import Image
import torchvision.transforms as transforms

title = "Digit Classifier"
description = (
    "A multilayer perceptron built for the fast.ai 'Deep Learning' course "
    "to classify handwritten digits from the MNIST dataset."
)
inputs = gr.components.Image()
outputs = gr.components.Label()
examples = "examples"  # directory of sample digit images shipped with the app

# Load the full pickled model onto the CPU and switch it to inference mode
model = torch.load("model/digit_classifier.pt", map_location=torch.device("cpu"))
model.eval()
labels = [str(i) for i in range(10)]

# Preprocess an input image to match MNIST: 28x28, single channel, values in [0, 1].
# ToTensor already yields a (1, 28, 28) tensor, which the model treats as a batch of one image.
transform = transforms.Compose(
    [
        transforms.Resize((28, 28)),
        transforms.Grayscale(),
        transforms.ToTensor(),
    ]
)


def predict_digit(img):
    """Return a {digit: probability} mapping for a drawn or uploaded image."""
    # Gradio passes the image as a numpy array; convert to PIL and preprocess
    img = transform(Image.fromarray(img))
    with torch.no_grad():  # inference only; no gradients needed
        output = model(img)
    probs = torch.nn.functional.softmax(output, dim=1)
    # The Label component expects a dict mapping each class to its confidence
    return dict(zip(labels, map(float, probs.flatten()[:10])))


with gr.Blocks() as demo:
    with gr.Tab("Digit Prediction"):
        gr.Interface(
            fn=predict_digit,
            inputs=inputs,
            outputs=outputs,
            examples=examples,
            title=title,
            description=description,
        )

# Queue the Blocks app that is actually launched, limiting concurrent predictions
demo.queue(default_concurrency_limit=5)
demo.launch()
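
# To try the app locally (assuming the repo's model/ checkpoint and examples/ images are present):
#   python app.py
# then open the local URL Gradio prints (http://127.0.0.1:7860 by default).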