File size: 2,700 Bytes
fca2efd
 
 
 
 
9d1a8a7
3647577
fca2efd
 
 
 
 
 
 
 
088dea4
 
fca2efd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2f7f294
367d735
 
fca2efd
 
 
9d1a8a7
088dea4
 
367d735
fca2efd
9d1a8a7
 
 
fca2efd
 
 
 
30935b4
fca2efd
 
 
 
 
 
 
 
 
 
 
 
 
 
9d1a8a7
 
 
fca2efd
 
088dea4
fca2efd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9d1a8a7
 
 
 
 
 
 
 
 
fca2efd
 
367d735
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import argparse
import glob
import os

import numpy as np
from norfair import AbsolutePaths, FixedCamera, Paths, Tracker, Video
from norfair.camera_motion import HomographyTransformationGetter, MotionEstimator

from inference_utils import (
    YOLO,
    center,
    clean_videos,
    draw,
    euclidean_distance,
    iou,
    models_path,
    style,
    yolo_detections_to_norfair_detections,
)

# Tracker `distance_threshold` used when objects are tracked as bounding boxes
# (paired with the `iou` distance function in `inference`).
# NOTE(review): if `iou` returns 1 - IoU (range [0, 1]), a threshold above 1
# never rejects a candidate match — confirm against inference_utils.iou.
DISTANCE_THRESHOLD_BBOX: float = 3.33
# Tracker `distance_threshold` used when objects are tracked as single points
# (paired with `euclidean_distance`; presumably measured in pixels — confirm).
DISTANCE_THRESHOLD_CENTROID: int = 30
# Not referenced by the code visible in this module; may be imported elsewhere.
MAX_DISTANCE: int = 10_000


def inference(
    input_video: str,
    model: str,
    motion_estimation: bool,
    drawing_paths: bool,
    track_points: str,
    model_threshold: str,
):
    """Run YOLO detection plus norfair tracking over a video and save an annotated copy.

    Each frame is passed through the selected YOLO model. The detections are
    converted to norfair detections and fed to a norfair ``Tracker``, and the
    tracked objects are drawn onto the frame, which is written to the output
    video.

    Args:
        input_video: Path to the input video file.
        model: Key into ``models_path`` selecting the YOLO weights to load.
        motion_estimation: If true, a homography-based ``MotionEstimator`` is
            updated on every frame. Its coordinate transformations are passed
            to the tracker and to ``draw``.
        drawing_paths: If true, object paths are drawn. When combined with
            ``motion_estimation``, ``AbsolutePaths`` is used; otherwise a
            ``Paths`` drawer with attenuation 0.01 is used.
        track_points: Key into ``style``. If the mapped value is ``"bbox"``,
            tracking uses ``iou`` / ``DISTANCE_THRESHOLD_BBOX``; otherwise it
            uses ``euclidean_distance`` / ``DISTANCE_THRESHOLD_CENTROID``.
        model_threshold: Forwarded to the model as ``conf_threshold``.
            NOTE(review): annotated ``str`` but used as a confidence threshold;
            confirm what callers actually pass.

    Returns:
        Path of the annotated video inside the ``"tmp"`` output directory.
    """
    output_path = "tmp"

    clean_videos(output_path)

    track_points = style[track_points]
    model = YOLO(models_path[model])
    video = Video(input_path=input_video, output_path=output_path)

    # Paths are drawn in absolute coordinates only when camera motion is known.
    fix_paths = bool(motion_estimation and drawing_paths)

    motion_estimator = None
    if motion_estimation:
        motion_estimator = MotionEstimator(
            max_points=500,
            min_distance=7,
            transformations_getter=HomographyTransformationGetter(),
        )

    if track_points == "bbox":
        distance_function = iou
        distance_threshold = DISTANCE_THRESHOLD_BBOX
    else:
        distance_function = euclidean_distance
        distance_threshold = DISTANCE_THRESHOLD_CENTROID
    tracker = Tracker(
        distance_function=distance_function,
        distance_threshold=distance_threshold,
    )

    # Build only the drawer that will actually be used: previously a `Paths`
    # instance was created and then immediately replaced when `fix_paths` was set.
    paths_drawer = None
    if fix_paths:
        paths_drawer = AbsolutePaths(max_history=5, thickness=2)
    elif drawing_paths:
        paths_drawer = Paths(center, attenuation=0.01)

    # Stays None when motion estimation is disabled.
    coord_transformations = None
    for frame in video:
        yolo_detections = model(
            frame, conf_threshold=model_threshold, iou_threshold=0.45, image_size=720
        )

        if motion_estimator is not None:
            # Mask of ones covering the whole frame. It is only allocated when
            # it is actually consumed, instead of on every frame unconditionally.
            mask = np.ones(frame.shape[:2], frame.dtype)
            coord_transformations = motion_estimator.update(frame, mask)

        detections = yolo_detections_to_norfair_detections(
            yolo_detections, track_points=track_points
        )

        tracked_objects = tracker.update(
            detections=detections, coord_transformations=coord_transformations
        )

        frame = draw(
            paths_drawer,
            track_points,
            frame,
            detections,
            tracked_objects,
            coord_transformations,
            fix_paths,
        )
        video.write(frame)

    # The returned name must match the file the video writer produced. This is
    # presumably norfair's own "<stem>_out.mp4" naming, so it is computed exactly
    # the same way (split on "/" and on the first ".") — keep the two in sync.
    base_file_name = input_video.split("/")[-1].split(".")[0]
    file_name = base_file_name + "_out.mp4"
    return os.path.join(output_path, file_name)