kisa-misa committed on
Commit
a44b5ff
1 Parent(s): e0008d0

Upload yolos_minimal_inference_example.py

Browse files
Files changed (1) hide show
  1. yolos_minimal_inference_example.py +107 -0
yolos_minimal_inference_example.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """YOLOS minimal inference example.ipynb
3
+
4
+ Automatically generated by Colaboratory.
5
+
6
+ Original file is located at
7
+ https://colab.research.google.com/github/NielsRogge/Transformers-Tutorials/blob/master/YOLOS/YOLOS_minimal_inference_example.ipynb
8
+
9
+ ## Set-up environment
10
+
11
+ First, we install the Hugging Face Transformers library (from source for now, as the model was just added to the library and is not yet included in a PyPI release).
12
+ """
13
+
14
# Environment set-up. These are shell / notebook commands, not Python, so they
# are kept as comments; run them in a terminal (or a notebook cell) before
# executing this script:
#
#   pip install -q git+https://github.com/huggingface/transformers.git
#   pip install gradio
17
+
18
+ import gradio as gr
19
+ from gradio.mix import Series
20
+ from PIL import Image
21
+ import requests
22
+ from transformers import AutoFeatureExtractor, YolosForObjectDetection
23
+ import torch
24
+ import matplotlib.pyplot as plt
25
+ import cv2
26
+
27
+ import os
28
+ os.getcwd()
29
+
30
# RGB colours (components in the 0-1 range) used for the bounding-box outlines
# drawn by plot_results(); the palette is repeated when there are more boxes
# than colours.
COLORS = [
    [0.000, 0.447, 0.741],
    [0.850, 0.325, 0.098],
    [0.929, 0.694, 0.125],
    [0.494, 0.184, 0.556],
    [0.466, 0.674, 0.188],
    [0.301, 0.745, 0.933],
]
33
+
34
def plot_results(pil_img, prob, boxes, count):
    """Draw the detections on one frame and save it as ``exp2/frameNN.png``.

    Args:
        pil_img: The frame to display (anything ``plt.imshow`` accepts).
        prob: Per-detection class scores, one row per box; each row must
            support ``argmax()`` and indexing by the resulting class id.
        boxes: Boxes as ``(xmin, ymin, xmax, ymax)`` rows; must support
            ``.tolist()``.
        count: Frame number used in the output file name (zero-padded to at
            least two digits).

    The class names are looked up in the module-level ``model.config.id2label``.
    """
    fig = plt.figure(figsize=(16, 10))
    try:
        plt.imshow(pil_img)
        ax = plt.gca()
        # Repeat the palette so zip() does not cut the detections short.
        colors = COLORS * 100
        for p, (xmin, ymin, xmax, ymax), c in zip(prob, boxes.tolist(), colors):
            ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,
                                       fill=False, color=c, linewidth=3))
            cl = p.argmax()
            text = f'{model.config.id2label[cl.item()]}: {p[cl]:0.2f}'
            ax.text(xmin, ymin, text, fontsize=15,
                    bbox=dict(facecolor='yellow', alpha=0.5))
        plt.axis('off')
        # savefig() does not create missing directories.
        os.makedirs('exp2', exist_ok=True)
        plt.savefig('exp2/frame%02d.png' % count)
    finally:
        # One figure is opened per frame; close it so figures do not pile up
        # in memory while a whole video is processed.
        plt.close(fig)
50
+
51
# Load the pretrained YOLOS detector and its matching pre-processor once.
model = YolosForObjectDetection.from_pretrained("hustvl/yolos-small")
feature_extractor = AutoFeatureExtractor.from_pretrained("hustvl/yolos-small")

vidcap = cv2.VideoCapture('/content/2022-08-10_ППП-стоянки_кам-3_191356 (online-video-cutter.com).mp4')
count = 0
try:
    # The first frame is read and discarded, so inside the loop `count` is the
    # 0-based index of the frame currently held in `image`.
    success, image = vidcap.read()
    while success:
        success, image = vidcap.read()
        if not success:
            # End of the stream: `image` is None and must not be processed.
            break
        count += 1

        # Only every 10th frame is run through the detector.
        if count % 10 != 0:
            continue

        # OpenCV returns frames in BGR channel order; convert to RGB before
        # handing the frame to PIL and the model.
        image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
        pixel_values = feature_extractor(image, return_tensors="pt").pixel_values

        # Attention maps are not used below, so they are not requested.
        with torch.no_grad():
            outputs = model(pixel_values)

        # keep only predictions of queries with 0.8+ confidence (excluding no-object class)
        probas = outputs.logits.softmax(-1)[0, :, :-1]
        keep = probas.max(-1).values > 0.8

        # rescale bounding boxes to the frame size; PIL's size is
        # (width, height), so it is reversed to (height, width)
        target_sizes = torch.tensor(image.size[::-1]).unsqueeze(0)
        postprocessed_outputs = feature_extractor.post_process(outputs, target_sizes)
        bboxes_scaled = postprocessed_outputs[0]['boxes']
        plot_results(image, probas[keep], bboxes_scaled[keep], count)

        print('Process a new frame: ', success)
finally:
    # Always free the video handle, even if inference fails part-way through.
    vidcap.release()
80
+
81
+ """Set model and directory parameters:
82
+
83
+ Assemble the saved frames from the given folder into a video:
84
+ """
85
+
86
# Folder holding the annotated frames. plot_results() writes them to the
# relative path 'exp2/', so this absolute path only points at the same folder
# when the working directory is /content.
image_folder = '/content/exp2'
file_list = os.listdir(image_folder)
88
+
89
def last_2chars(x):
    """Return the two characters of ``x`` at indices 5 and 6.

    For the frame files named ``frameNN.png`` this is the two-digit frame
    number that follows the five-letter ``frame`` prefix (e.g. ``"10"`` for
    ``"frame10.png"``) -- not the last two characters of the name. Strings
    shorter than six characters yield a shorter (possibly empty) result.
    """
    return x[5:7]
92
+
93
def _frame_number(file_name):
    """Return the integer formed by the digits in the file stem (0 if none)."""
    stem = os.path.splitext(file_name)[0]
    digits = "".join(ch for ch in stem if ch.isdigit())
    return int(digits) if digits else 0


# Order frames by their numeric index. A two-character string key would place
# 'frame100.png' next to 'frame10.png' and scramble videos with 100+ frames.
srtd = sorted(file_list, key=lambda name: (_frame_number(name), name))

video_name = 'video.avi'

images = [img for img in srtd if img.endswith(".png")]
if not images:
    raise FileNotFoundError(f"no .png frames found in {image_folder}")

# The first frame determines the size of the output video.
frame = cv2.imread(os.path.join(image_folder, images[0]))
height, width = frame.shape[:2]

# fourcc 0 and 5 frames per second, as before.
video = cv2.VideoWriter(video_name, 0, 5, (width, height))
try:
    for image_name in images:
        video.write(cv2.imread(os.path.join(image_folder, image_name)))
finally:
    # Finalise the file even if reading or writing a frame fails.
    video.release()