okeowo1014 commited on
Commit
df2a532
1 Parent(s): 41b11ad

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +76 -0
app.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ from flask import Flask, render_template, request, jsonify
3
+ import torch
4
+ import torchvision.transforms as transforms
5
+ from PIL import Image
6
+ import torch.nn.functional as F
7
+ import torch.nn as nn
8
+
9
+ num_classes = 10
10
+
11
+ # Class definition for the model (same as in your code)
12
+ class FingerprintRecognitionModel(nn.Module):
13
+ def __init__(self, num_classes):
14
+ super(FingerprintRecognitionModel, self).__init__()
15
+ self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
16
+ self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
17
+ self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
18
+ self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
19
+ self.fc1 = nn.Linear(128 * 28 * 28, 256)
20
+ self.fc2 = nn.Linear(256, num_classes)
21
+
22
+ def forward(self, x):
23
+ x = self.pool(F.relu(self.conv1(x)))
24
+ x = self.pool(F.relu(self.conv2(x)))
25
+ x = self.pool(F.relu(self.conv3(x)))
26
+ x = x.view(-1, 128 * 28 * 28)
27
+ x = F.relu(self.fc1(x))
28
+ x = F.softmax(self.fc2(x), dim=1)
29
+ return x
30
+
31
+ app = Flask(__name__)
32
+
33
+ # Load the model
34
+ model_path = 'fingerprint_recognition_model_bs32_lr0.001_opt_Adam.pt'
35
+ model = FingerprintRecognitionModel(num_classes)
36
+ model.load_state_dict(torch.load(model_path))
37
+ model.eval()
38
+
39
+ def preprocess_image(image_bytes):
40
+ # Convert bytes to PIL Image
41
+ image = Image.open(io.BytesIO(image_bytes)).convert('L') # Convert to grayscale
42
+
43
+ # Resize to 224x224
44
+ img_resized = image.resize((224, 224))
45
+
46
+ transform = transforms.Compose([
47
+ transforms.ToTensor(),
48
+ transforms.Normalize((0.5,), (0.5,))
49
+ ])
50
+
51
+ # Apply transforms and add batch dimension
52
+ img_tensor = transform(img_resized).unsqueeze(0)
53
+
54
+ return img_tensor
55
+
56
+ def predict_class(image_bytes):
57
+ img_tensor = preprocess_image(image_bytes)
58
+ with torch.no_grad():
59
+ outputs = model(img_tensor)
60
+ _, predicted = torch.max(outputs.data, 1)
61
+ predicted_class = int(predicted.item())
62
+ return predicted_class
63
+
64
+ @app.route('/', methods=['GET', 'POST'])
65
+ def index():
66
+ if request.method == 'POST':
67
+ file = request.files['file']
68
+ if file:
69
+ contents = file.read()
70
+ predicted_class = predict_class(contents)
71
+ class_labels = {0:'left_index_fingers',1:'left_little_fingers',2:'left_middle_fingers',3: 'left_ring_fingers', 4:'left_thumb_fingers',5:'right_index_fingers',6:'right_little_fingers',7:'right_middle_fingers',8:'right_ring_fingers',9: 'right_thumb_fingers'}
72
+ return jsonify({'predicted_class': predicted_class, 'class_label': class_labels[predicted_class]})
73
+ return render_template('index.html')
74
+
75
+ if __name__ == '__main__':
76
+ app.run(debug=True)