Spaces:
Runtime error
Runtime error
File size: 1,938 Bytes
aa56c48 a85ef54 aa56c48 ea81160 0f7c18e d40ba1d 0f7c18e d40ba1d 0f7c18e d40ba1d 0f7c18e d40ba1d 0f7c18e ea81160 a85ef54 aa56c48 d40ba1d 0f7c18e a85ef54 aa56c48 5940022 0f7c18e aa56c48 0f7c18e aa56c48 5940022 0f7c18e aa56c48 ea81160 aa56c48 0f7c18e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 |
import torch
import gradio as gr
from torchvision import transforms
from PIL import ImageOps
import os
from dotenv import load_dotenv
from torch import nn
import torch.nn.functional as F
class Lenet(nn.Module):
    """LeNet-5-style convolutional network for single-channel 28x28 digit images.

    The layer sizes only line up for 28x28 inputs:
    28 -> 28 (conv1, padding=2) -> 14 (pool1) -> 10 (conv2) -> 5 (pool2) -> 1 (conv3).
    With that input, conv3 yields exactly the 120 features ``fc1`` expects.
    The network outputs 10 unnormalized scores (logits) per image.
    """

    def __init__(self, args=None):
        # ``args`` is accepted but unused; kept so existing callers keep working.
        super().__init__()
        self.conv1 = nn.Conv2d(1, 6, 5, padding=2)  # -> 6 channels, 28x28
        self.pool1 = nn.MaxPool2d(2)  # -> 6 channels, 14x14
        self.conv2 = nn.Conv2d(6, 16, 5)  # -> 16 channels, 10x10
        self.pool2 = nn.MaxPool2d(2)  # -> 16 channels, 5x5
        self.conv3 = nn.Conv2d(16, 120, 5)  # -> 120 channels, 1x1
        self.fc1 = nn.Linear(120, 84)
        self.fc2 = nn.Linear(84, 10)

    def forward(self, x):
        """Compute class logits for a batch ``x`` of shape (N, 1, 28, 28).

        Defined as ``forward`` rather than overriding ``__call__`` so that
        ``model(x)`` goes through ``nn.Module.__call__`` and registered hooks run.

        Returns:
            Tensor of shape (N, 10) with unnormalized class scores.
        """
        xx = F.relu(self.conv1(x))
        # Max-pooling non-negative activations stays non-negative, so a ReLU
        # after the pool would be a no-op and is omitted.
        xx = self.pool1(xx)
        xx = F.relu(self.conv2(xx))
        xx = self.pool2(xx)
        xx = F.relu(self.conv3(xx))
        xx = xx.flatten(1)  # (N, 120, 1, 1) -> (N, 120)
        xx = F.relu(self.fc1(xx))
        return self.fc2(xx)
# Load variables from a local .env file (if any) into the process environment.
load_dotenv()
# Flagging callback that saves flagged samples to the "simple-mnist-flagging"
# Hugging Face dataset, authenticated with the HF_TOKEN environment variable.
# NOTE(review): if HF_TOKEN is unset the saver receives None; confirm the
# deployment provides the secret or has another Hugging Face login available.
hf_writer = gr.HuggingFaceDatasetSaver(
    os.environ.get('HF_TOKEN'),
    "simple-mnist-flagging",
)
def load_model(path='model.pt', map_location='cpu'):
    """Build a ``Lenet`` and load trained weights into it.

    Args:
        path: Checkpoint file containing a ``Lenet`` state_dict. Defaults to
            ``'model.pt'``, the path this function previously hard-coded.
        map_location: Device the stored tensors are remapped to. Defaults to
            CPU so a checkpoint saved from a GPU still loads on a CPU-only host.

    Returns:
        The ``Lenet`` instance with loaded weights, switched to eval mode.
    """
    model = Lenet()
    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    # On torch >= 1.13, consider weights_only=True to restrict unpickling.
    state_dict = torch.load(path, map_location=map_location)
    model.load_state_dict(state_dict)
    model.eval()
    return model
# Converts PIL images to float tensors of shape (C, H, W) with values in [0, 1].
convert_tensor = transforms.ToTensor()
# Load the network once at import time; every prediction reuses this instance.
model = load_model()
def predict(img):
    """Classify a hand-drawn digit.

    Args:
        img: PIL image from the drawing canvas, or None if nothing was submitted.

    Returns:
        The predicted class index (0-9) as an int, or None when ``img`` is None.
    """
    if img is None:
        # Gradio may pass None for an empty submission; ImageOps.grayscale
        # would raise on it, so return an empty result instead.
        return None
    gray = ImageOps.grayscale(img).resize((28, 28))
    # ToTensor gives (1, 28, 28) for a grayscale image; add the batch axis.
    image_tensor = convert_tensor(gray).unsqueeze(0)
    with torch.no_grad():  # inference only; skip autograd bookkeeping
        logits = model(image_tensor)
    return int(torch.argmax(logits, dim=1).item())
# UI text shown above the drawing canvas.
title = "Handwritten digit recognition"
description = '<p><center>Write a single digit in the middle of the canvas</center></p>'

# Drawing canvas -> predict -> text output. Flagging is manual, with two
# labels, and flagged samples go through hf_writer.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Paint(type="pil", invert_colors=True),
    outputs="text",
    title=title,
    description=description,
    allow_flagging='manual',
    flagging_options=["incorrect", "ambiguous"],
    flagging_callback=hf_writer,
)
demo.launch()
|