Diego Carpintero commited on
Commit
2fb361b
1 Parent(s): 31cd6a1

implement gradio app.py

Browse files
Files changed (1) hide show
  1. app.py +46 -4
app.py CHANGED
@@ -1,7 +1,49 @@
1
  import gradio as gr
 
 
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
5
 
6
- iface = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- iface.launch(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
+ import torch
3
+ from model import *
4
 
5
+ from PIL import Image
6
+ import torchvision.transforms as transforms
7
 
8
+ title = "Digit Classifier"
9
+ description = (
10
+ "Multilayer-Perceptron built for the fast.ai 'Deep Learning' course "
11
+ "to classify handwritten digits from the MNIST dataset. "
12
+ )
13
+ inputs = gr.components.Image()
14
+ outputs = gr.components.Label()
15
+ examples = "examples"
16
+
17
+ model = torch.load("model/digit_classifier.pt", map_location=torch.device("cpu"))
18
+ labels = [str(i) for i in range(10)]
19
+
20
+ transform = transforms.Compose(
21
+ [
22
+ transforms.Resize((28, 28)),
23
+ transforms.Grayscale(),
24
+ transforms.ToTensor(),
25
+ transforms.Lambda(lambda x: x[0]),
26
+ transforms.Lambda(lambda x: x.unsqueeze(0)),
27
+ ]
28
+ )
29
+
30
+
31
+ def predict_digit(img):
32
+ img = transform(Image.fromarray(img))
33
+ output = model(img)
34
+ probs = torch.nn.functional.softmax(output, dim=1)
35
+ return dict(zip(labels, map(float, probs.flatten()[:10])))
36
+
37
+
38
+ with gr.Blocks() as demo:
39
+ with gr.Tab("Digit Prediction"):
40
+ gr.Interface(
41
+ fn=predict_digit,
42
+ inputs=inputs,
43
+ outputs=outputs,
44
+ examples=examples,
45
+ title=title,
46
+ description=description,
47
+ ).queue(default_concurrency_limit=5)
48
+
49
+ demo.launch()