narugo1992 commited on
Commit
582519c
1 Parent(s): 2023a9f

dev(narugo): add monochrome

Browse files
Files changed (4) hide show
  1. app.py +18 -0
  2. cls.py +21 -17
  3. monochrome.py +42 -0
  4. requirements.txt +2 -1
app.py CHANGED
@@ -3,6 +3,7 @@ import os
3
  import gradio as gr
4
 
5
  from cls import _CLS_MODELS, _DEFAULT_CLS_MODEL, _gr_classification
 
6
 
7
  if __name__ == '__main__':
8
  with gr.Blocks() as demo:
@@ -24,4 +25,21 @@ if __name__ == '__main__':
24
  outputs=[gr_cls_output],
25
  )
26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  demo.queue(os.cpu_count()).launch()
 
3
  import gradio as gr
4
 
5
  from cls import _CLS_MODELS, _DEFAULT_CLS_MODEL, _gr_classification
6
+ from monochrome import _gr_monochrome, _DEFAULT_MONO_MODEL, _MONO_MODELS
7
 
8
  if __name__ == '__main__':
9
  with gr.Blocks() as demo:
 
25
  outputs=[gr_cls_output],
26
  )
27
 
28
+ with gr.Tab('Monochrome'):
29
+ with gr.Row():
30
+ with gr.Column():
31
+ gr_mono_input_image = gr.Image(type='pil', label='Original Image')
32
+ gr_mono_model = gr.Dropdown(_MONO_MODELS, value=_DEFAULT_MONO_MODEL, label='Model')
33
+ gr_mono_infer_size = gr.Slider(224, 640, value=384, step=32, label='Infer Size')
34
+ gr_mono_submit = gr.Button(value='Submit', variant='primary')
35
+
36
+ with gr.Column():
37
+ gr_mono_output = gr.Label(label='Classes')
38
+
39
+ gr_mono_submit.click(
40
+ _gr_monochrome,
41
+ inputs=[gr_mono_input_image, gr_mono_model, gr_mono_infer_size],
42
+ outputs=[gr_mono_output],
43
+ )
44
+
45
  demo.queue(os.cpu_count()).launch()
cls.py CHANGED
@@ -1,31 +1,34 @@
 
 
1
  from functools import lru_cache
2
- from typing import Mapping
3
 
4
- from huggingface_hub import hf_hub_download
5
  from imgutils.data import ImageTyping, load_image
 
6
 
7
  from onnx_ import _open_onnx_model
8
  from preprocess import _img_encode
9
 
10
- _LABELS = ['3d', 'bangumi', 'comic', 'illustration']
11
- _CLS_MODELS = [
12
- 'caformer_s36',
13
- 'caformer_s36_plus',
14
- 'mobilenetv3',
15
- 'mobilenetv3_dist',
16
- 'mobilenetv3_sce',
17
- 'mobilenetv3_sce_dist',
18
- 'mobilevitv2_150',
19
- ]
20
  _DEFAULT_CLS_MODEL = 'mobilenetv3_sce_dist'
21
 
22
 
23
  @lru_cache()
24
  def _open_anime_classify_model(model_name):
25
- return _open_onnx_model(hf_hub_download(
26
- f'deepghs/anime_classification',
27
- f'{model_name}/model.onnx',
28
- ))
 
 
 
29
 
30
 
31
  def _gr_classification(image: ImageTyping, model_name: str, size=384) -> Mapping[str, float]:
@@ -33,5 +36,6 @@ def _gr_classification(image: ImageTyping, model_name: str, size=384) -> Mapping
33
  input_ = _img_encode(image, size=(size, size))[None, ...]
34
  output, = _open_anime_classify_model(model_name).run(['output'], {'input': input_})
35
 
36
- values = dict(zip(_LABELS, map(lambda x: x.item(), output[0])))
 
37
  return values
 
1
+ import json
2
+ import os
3
  from functools import lru_cache
4
+ from typing import Mapping, List
5
 
6
+ from huggingface_hub import hf_hub_download, HfFileSystem
7
  from imgutils.data import ImageTyping, load_image
8
+ from natsort import natsorted
9
 
10
  from onnx_ import _open_onnx_model
11
  from preprocess import _img_encode
12
 
13
+ hfs = HfFileSystem()
14
+
15
+ _REPO = 'deepghs/anime_classification'
16
+ _CLS_MODELS = natsorted([
17
+ os.path.dirname(os.path.relpath(file, _REPO))
18
+ for file in hfs.glob(f'{_REPO}/*/model.onnx')
19
+ ])
 
 
 
20
  _DEFAULT_CLS_MODEL = 'mobilenetv3_sce_dist'
21
 
22
 
23
  @lru_cache()
24
  def _open_anime_classify_model(model_name):
25
+ return _open_onnx_model(hf_hub_download(_REPO, f'{model_name}/model.onnx'))
26
+
27
+
28
+ @lru_cache()
29
+ def _get_tags(model_name) -> List[str]:
30
+ with open(hf_hub_download(_REPO, f'{model_name}/meta.json'), 'r') as f:
31
+ return json.load(f)['labels']
32
 
33
 
34
  def _gr_classification(image: ImageTyping, model_name: str, size=384) -> Mapping[str, float]:
 
36
  input_ = _img_encode(image, size=(size, size))[None, ...]
37
  output, = _open_anime_classify_model(model_name).run(['output'], {'input': input_})
38
 
39
+ labels = _get_tags(model_name)
40
+ values = dict(zip(labels, map(lambda x: x.item(), output[0])))
41
  return values
monochrome.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ from functools import lru_cache
4
+ from typing import Mapping, List
5
+
6
+ from huggingface_hub import HfFileSystem
7
+ from huggingface_hub import hf_hub_download
8
+ from imgutils.data import ImageTyping, load_image
9
+ from natsort import natsorted
10
+
11
+ from onnx_ import _open_onnx_model
12
+ from preprocess import _img_encode
13
+
14
+ hfs = HfFileSystem()
15
+
16
+ _REPO = 'deepghs/monochrome_detect'
17
+ _MONO_MODELS = natsorted([
18
+ os.path.dirname(os.path.relpath(file, _REPO))
19
+ for file in hfs.glob(f'{_REPO}/*/model.onnx')
20
+ ])
21
+ _DEFAULT_MONO_MODEL = 'mobilenetv3_large_100_dist'
22
+
23
+
24
+ @lru_cache()
25
+ def _open_anime_monochrome_model(model_name):
26
+ return _open_onnx_model(hf_hub_download(_REPO, f'{model_name}/model.onnx'))
27
+
28
+
29
+ @lru_cache()
30
+ def _get_tags(model_name) -> List[str]:
31
+ with open(hf_hub_download(_REPO, f'{model_name}/meta.json'), 'r') as f:
32
+ return json.load(f)['labels']
33
+
34
+
35
+ def _gr_monochrome(image: ImageTyping, model_name: str, size=384) -> Mapping[str, float]:
36
+ image = load_image(image, mode='RGB')
37
+ input_ = _img_encode(image, size=(size, size))[None, ...]
38
+ output, = _open_anime_monochrome_model(model_name).run(['output'], {'input': input_})
39
+
40
+ labels = _get_tags(model_name)
41
+ values = dict(zip(labels, map(lambda x: x.item(), output[0])))
42
+ return values
requirements.txt CHANGED
@@ -2,10 +2,11 @@ gradio==3.18.0
2
  numpy
3
  pillow
4
  onnxruntime
5
- huggingface_hub
6
  scikit-image
7
  pandas
8
  opencv-python>=4.6.0
9
  hbutils>=0.9.0
10
  dghs-imgutils>=0.1.0
11
  httpx==0.23.0
 
 
2
  numpy
3
  pillow
4
  onnxruntime
5
+ huggingface_hub>=0.14.0
6
  scikit-image
7
  pandas
8
  opencv-python>=4.6.0
9
  hbutils>=0.9.0
10
  dghs-imgutils>=0.1.0
11
  httpx==0.23.0
12
+ natsort