"""Anime image classification using ONNX models hosted in the
``deepghs/anime_classification`` Hugging Face repository."""
import json
import os
from functools import lru_cache
from typing import Mapping, List

from huggingface_hub import hf_hub_download, HfFileSystem
from imgutils.data import ImageTyping, load_image
from natsort import natsorted

from onnx_ import _open_onnx_model
from preprocess import _img_encode

hfs = HfFileSystem()

_REPO = 'deepghs/anime_classification'
# Every sub-directory of the repo that contains a ``model.onnx`` is exposed as a
# selectable model name, in natural sort order.
# NOTE: this lists the remote repository at import time, so importing this
# module requires network access to the Hugging Face Hub.
_CLS_MODELS = natsorted([
    os.path.dirname(os.path.relpath(file, _REPO))
    for file in hfs.glob(f'{_REPO}/*/model.onnx')
])
_DEFAULT_CLS_MODEL = 'mobilenetv3_sce_dist'


@lru_cache()
def _open_anime_classify_model(model_name):
    """Download (if needed) and open ``<model_name>/model.onnx`` from the repo.

    The result is cached per ``model_name``, so each model is downloaded and
    loaded only once per process.

    :param model_name: Sub-directory name in the repo (one of ``_CLS_MODELS``).
    :return: Whatever ``_open_onnx_model`` returns; it is used below via
        ``.run([...], {...})``, presumably an ONNX Runtime inference session.
    """
    return _open_onnx_model(hf_hub_download(_REPO, f'{model_name}/model.onnx'))


@lru_cache()
def _get_tags(model_name) -> List[str]:
    """Return the label list stored in ``<model_name>/meta.json``.

    The order of the labels is assumed to match the order of the model's
    output scores (they are paired positionally by ``_gr_classification``).

    :param model_name: Sub-directory name in the repo.
    :return: The ``labels`` entry of the model's ``meta.json``.
    """
    # JSON is UTF-8 by specification; don't depend on the platform's locale
    # encoding (e.g. cp1252 on Windows), which would corrupt non-ASCII labels.
    with open(hf_hub_download(_REPO, f'{model_name}/meta.json'), 'r', encoding='utf-8') as f:
        return json.load(f)['labels']


def _gr_classification(image: ImageTyping, model_name: str, size=384) -> Mapping[str, float]:
    """Classify an image and return one score per label.

    :param image: Anything ``imgutils.data.load_image`` accepts; it is
        converted to RGB.
    :param model_name: Sub-directory name in the repo (one of ``_CLS_MODELS``).
    :param size: Side length passed to ``_img_encode`` as ``(size, size)``.
    :return: Mapping from label name to the model's score for that label, as
        Python floats. Whether scores are probabilities depends on the model.
    :raises ValueError: If the number of labels in ``meta.json`` does not match
        the number of scores produced by the model.
    """
    image = load_image(image, mode='RGB')
    # Add a leading batch axis of size 1 to the encoded image.
    input_ = _img_encode(image, size=(size, size))[None, ...]
    output, = _open_anime_classify_model(model_name).run(['output'], {'input': input_})
    scores = output[0]  # scores for the single image in the batch

    labels = _get_tags(model_name)
    # zip() would silently truncate on a mismatch and hide a broken
    # model/meta pairing, so check the lengths explicitly.
    if len(labels) != len(scores):
        raise ValueError(
            f'Model {model_name!r} produced {len(scores)} scores '
            f'but meta.json lists {len(labels)} labels.'
        )
    return {label: score.item() for label, score in zip(labels, scores)}