File size: 1,938 Bytes
aa56c48
a85ef54
aa56c48
 
ea81160
 
 
0f7c18e
 
 
d40ba1d
0f7c18e
 
 
d40ba1d
 
 
 
 
 
0f7c18e
 
 
d40ba1d
0f7c18e
d40ba1d
 
0f7c18e
 
 
 
ea81160
 
 
a85ef54
aa56c48
d40ba1d
0f7c18e
 
 
a85ef54
aa56c48
 
 
 
5940022
0f7c18e
 
 
 
aa56c48
0f7c18e
 
aa56c48
 
5940022
0f7c18e
aa56c48
ea81160
 
aa56c48
0f7c18e
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import torch
import gradio as gr
from torchvision import transforms
from PIL import ImageOps
import os
from dotenv import load_dotenv

from torch import nn
import torch.nn.functional as F

class Lenet(nn.Module):
    """LeNet-style convolutional classifier producing 10 logits per image.

    The layer sizes require single-channel 28x28 inputs: after the three
    convolutions and two poolings the feature map is 120 channels of 1x1,
    which is what ``fc1`` expects once flattened.
    """

    def __init__(self, args=None):
        """Create the convolutional and fully connected layers.

        Args:
            args: Not used by this class.
        """
        super().__init__()
        self.conv1 = nn.Conv2d(1, 6, 5, padding=2)  # -> 6 channels, 28x28
        self.pool1 = nn.MaxPool2d(2)  # -> 6 channels, 14x14
        self.conv2 = nn.Conv2d(6, 16, 5)  # -> 16 channels, 10x10
        self.pool2 = nn.MaxPool2d(2)  # -> 16 channels, 5x5
        self.conv3 = nn.Conv2d(16, 120, 5)  # -> 120 channels, 1x1
        self.fc1 = nn.Linear(120, 84)
        self.fc2 = nn.Linear(84, 10)

    def forward(self, x):
        """Compute class logits for a batch of images.

        Defined as ``forward`` (not ``__call__``) so that ``nn.Module.__call__``
        keeps running registered hooks before and after delegating here.

        Args:
            x: Float tensor of shape (batch, 1, 28, 28).

        Returns:
            Tensor of shape (batch, 10) holding unnormalized class scores.
        """
        # Max-pooling a ReLU output is already non-negative, so no second
        # ReLU is needed after the pooling layers.
        xx = self.pool1(F.relu(self.conv1(x)))
        xx = self.pool2(F.relu(self.conv2(xx)))
        xx = F.relu(self.conv3(xx))
        xx = xx.flatten(1)  # (batch, 120, 1, 1) -> (batch, 120)
        xx = F.relu(self.fc1(xx))
        return self.fc2(xx)

# Load variables from a local .env file (if any) into the process
# environment before HF_TOKEN is read below.
load_dotenv()

# Callback handed to the interface for storing flagged samples; it receives
# the HF_TOKEN environment value (None when unset) and the dataset name.
hf_writer = gr.HuggingFaceDatasetSaver(
    os.environ.get('HF_TOKEN'),
    "simple-mnist-flagging",
)

def load_model(path='model.pt'):
    """Build a ``Lenet``, load its weights from *path*, and set eval mode.

    Args:
        path: Checkpoint file containing the model's state dict.
            Defaults to ``'model.pt'``.

    Returns:
        The ``Lenet`` instance with the loaded weights, in evaluation mode.
    """
    model = Lenet()
    # map_location='cpu' lets a checkpoint saved from a GPU load on a
    # CPU-only host; without it torch.load tries to restore the tensors on
    # their original device and fails when that device is missing.
    # NOTE: torch.load unpickles the file, so only load trusted checkpoints.
    state_dict = torch.load(path, map_location='cpu')
    model.load_state_dict(state_dict)
    model.eval()
    return model

# PIL image -> float tensor converter, reused by every prediction.
convert_tensor = transforms.ToTensor()
# Single module-level network instance, loaded once at start-up.
model = load_model()

def predict(img):
    """Classify the digit drawn on the canvas.

    Args:
        img: PIL image supplied by the input component, or ``None`` when
            nothing was submitted.

    Returns:
        The predicted class index as an ``int``, or an empty string when
        ``img`` is ``None`` so the text output stays blank instead of the
        call failing.
    """
    if img is None:
        return ""
    # The network expects a single-channel 28x28 input.
    img = ImageOps.grayscale(img).resize((28, 28))
    image_tensor = convert_tensor(img).view(1, 1, 28, 28)
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        logits = model(image_tensor)
    pred = torch.argmax(logits, dim=1)
    return pred.tolist()[0]

# Heading and helper text passed to the interface.
title = "Handwritten digit recognition"
description = '<p><center>Write a single digit in the middle of the canvas</center></p>'

# Drawing canvas that hands a PIL image to predict().
canvas = gr.Paint(type="pil", invert_colors=True)

# Wire the canvas to predict(), with manual flagging routed to hf_writer.
demo = gr.Interface(
    fn=predict,
    inputs=canvas,
    outputs="text",
    title=title,
    description=description,
    allow_flagging='manual',
    flagging_options=["incorrect", "ambiguous"],
    flagging_callback=hf_writer,
)
demo.launch()