File size: 4,101 Bytes
595be82
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import os.path
from functools import lru_cache, wraps
from typing import List, Tuple

import gradio as gr
from hbutils.color import rnd_colors
from hfutils.operate import get_hf_fs
from hfutils.utils import hf_fs_path, parse_hf_fs_path
from imgutils.data import ImageTyping
from imgutils.data import ImageTyping


def _cached_per_instance(method):
    """Memoize ``method`` separately for each instance it is called on.

    Unlike ``functools.lru_cache`` applied to a method, the cache lives in the
    instance's own ``__dict__``. It therefore does not keep the instance alive
    once it is otherwise unreachable, and entries of one instance are never
    evicted by calls on another. As with ``lru_cache``, all arguments must be
    hashable. A call that raises is not cached, so it is retried next time.
    """

    @wraps(method)
    def _wrapper(self, *args, **kwargs):
        cache = self.__dict__.setdefault('_object_detection_cache', {})
        # The function object is part of the key because one dict serves every
        # decorated method of the instance.
        key = (method, args, tuple(sorted(kwargs.items())))
        try:
            return cache[key]
        except KeyError:
            pass
        value = method(self, *args, **kwargs)
        cache[key] = value
        return value

    return _wrapper


class ObjectDetection:
    """Base class for an object-detection backend plus its Gradio demo UI.

    Subclasses implement the hooks ``_get_default_model``, ``_list_models``,
    ``_get_default_iou_and_score`` and ``_get_labels``, as well as ``detect``.
    The public getters wrap the hooks and cache their results per instance.
    """

    @_cached_per_instance
    def get_default_model(self) -> str:
        """Return the name of the model pre-selected in the UI (cached)."""
        return self._get_default_model()

    def _get_default_model(self) -> str:
        """Hook: name of the default model. Must be overridden."""
        raise NotImplementedError

    @_cached_per_instance
    def list_models(self) -> List[str]:
        """Return the names of all selectable models (cached)."""
        return self._list_models()

    def _list_models(self) -> List[str]:
        """Hook: names of all available models. Must be overridden."""
        raise NotImplementedError

    @_cached_per_instance
    def get_default_iou_and_score(self, model_name: str) -> Tuple[float, float]:
        """Return ``(iou_threshold, score_threshold)`` defaults for ``model_name`` (cached)."""
        return self._get_default_iou_and_score(model_name)

    def _get_default_iou_and_score(self, model_name: str) -> Tuple[float, float]:
        """Hook: default ``(iou_threshold, score_threshold)`` for a model. Must be overridden."""
        raise NotImplementedError

    @_cached_per_instance
    def get_labels(self, model_name: str) -> List[str]:
        """Return the label names produced by ``model_name`` (cached; do not mutate the list)."""
        return self._get_labels(model_name)

    def _get_labels(self, model_name: str) -> List[str]:
        """Hook: label names of a model. Must be overridden."""
        raise NotImplementedError

    def detect(self, image: ImageTyping, model_name: str,
               iou_threshold: float = 0.7, score_threshold: float = 0.25) \
            -> List[Tuple[Tuple[float, float, float, float], str, float]]:
        """Run ``model_name`` on ``image``. Must be overridden.

        :return: A list of ``(bbox, label, value)`` tuples, where ``bbox`` is a
            4-tuple of floats. The third float is presumably the detection
            score (it is unused here); the bbox coordinate convention is
            defined by the subclass.
        """
        raise NotImplementedError

    def _gr_detect(self, image: ImageTyping, model_name: str,
                   iou_threshold: float = 0.7, score_threshold: float = 0.25) \
            -> gr.AnnotatedImage:
        """Gradio callback: run ``detect`` and wrap the boxes in an ``AnnotatedImage``.

        Label colors come from ``rnd_colors`` and are regenerated on every call.
        """
        labels = self.get_labels(model_name=model_name)
        _colors = list(map(str, rnd_colors(len(labels))))
        _color_map = dict(zip(labels, _colors))
        detections = self.detect(image, model_name, iou_threshold, score_threshold)
        return gr.AnnotatedImage(
            # AnnotatedImage wants (bbox, label) pairs; the third field is dropped.
            value=(image, [(bbox, label) for bbox, label, _ in detections]),
            color_map=_color_map,
            label='Labeled',
        )

    def make_ui(self):
        """Build the demo layout: inputs on the left, the annotated result on the right.

        NOTE(review): creates Gradio layout components, so presumably it must be
        called inside an enclosing ``gr.Blocks`` context — confirm with callers.
        """
        with gr.Row():
            with gr.Column():
                default_model_name = self.get_default_model()
                model_list = self.list_models()
                gr_input_image = gr.Image(type='pil', label='Original Image')
                gr_model = gr.Dropdown(model_list, value=default_model_name, label='Model')
                with gr.Row():
                    iou, score = self.get_default_iou_and_score(default_model_name)
                    gr_iou_threshold = gr.Slider(0.0, 1.0, iou, label='IOU Threshold')
                    gr_score_threshold = gr.Slider(0.0, 1.0, score, label='Score Threshold')

                gr_submit = gr.Button(value='Submit', variant='primary')

            with gr.Column():
                gr_output_image = gr.AnnotatedImage(label="Labeled")

            # Threshold defaults are per model, so reset both sliders to the
            # newly selected model's defaults instead of keeping those of the
            # initial model.
            gr_model.change(
                self.get_default_iou_and_score,
                inputs=[gr_model],
                outputs=[gr_iou_threshold, gr_score_threshold],
            )

            gr_submit.click(
                self._gr_detect,
                inputs=[
                    gr_input_image,
                    gr_model,
                    gr_iou_threshold,
                    gr_score_threshold,
                ],
                outputs=[gr_output_image],
            )


class DeepGHSObjectDetection(ObjectDetection):
    """Detection backend whose models are stored in a Hugging Face model repository.

    A model is any top-level directory of the repository that contains a
    ``model.onnx`` file; the directory name serves as the model name.
    Only model discovery is implemented so far. The remaining hooks and
    ``detect`` raise ``NotImplementedError``.
    """

    def __init__(self, repo_id: str):
        # Repository id passed to ``hf_fs_path`` when globbing for models.
        self._repo_id = repo_id

    def _get_default_model(self) -> str:
        raise NotImplementedError

    def _list_models(self) -> List[str]:
        """Return the directory names that hold a ``model.onnx`` in the repository."""
        filesystem = get_hf_fs()
        onnx_pattern = hf_fs_path(
            repo_id=self._repo_id,
            repo_type='model',
            filename='*/model.onnx',
        )

        names: List[str] = []
        for onnx_file in filesystem.glob(onnx_pattern):
            # Path inside the repo, e.g. "<model>/model.onnx" -> keep "<model>".
            in_repo_path = parse_hf_fs_path(onnx_file).filename
            names.append(os.path.dirname(in_repo_path))
        return names

    def _get_default_iou_and_score(self, model_name: str) -> Tuple[float, float]:
        raise NotImplementedError

    def _get_labels(self, model_name: str) -> List[str]:
        raise NotImplementedError

    def detect(self, image: ImageTyping, model_name: str,
               iou_threshold: float = 0.7, score_threshold: float = 0.25) \
            -> List[Tuple[Tuple[float, float, float, float], str, float]]:
        raise NotImplementedError