import json import logging import os import pathlib import re from copy import deepcopy from pathlib import Path from typing import Optional, Tuple, Union, Dict, Any import torch _MODEL_CONFIG_PATHS = [Path(__file__).parent / f"model_configs/"] _MODEL_CONFIGS = {} # directory (model_name: config) of model architecture configs def _natural_key(string_): return [int(s) if s.isdigit() else s for s in re.split(r"(\d+)", string_.lower())] def _rescan_model_configs(): global _MODEL_CONFIGS config_ext = (".json",) config_files = [] for config_path in _MODEL_CONFIG_PATHS: if config_path.is_file() and config_path.suffix in config_ext: config_files.append(config_path) elif config_path.is_dir(): for ext in config_ext: config_files.extend(config_path.glob(f"*{ext}")) for cf in config_files: with open(cf, "r", encoding="utf8") as f: model_cfg = json.load(f) if all(a in model_cfg for a in ("embed_dim", "vision_cfg", "text_cfg")): _MODEL_CONFIGS[cf.stem] = model_cfg _MODEL_CONFIGS = dict(sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0]))) _rescan_model_configs() # initial populate of model config registry def list_models(): """enumerate available model architectures based on config files""" return list(_MODEL_CONFIGS.keys()) def add_model_config(path): """add model config path or file and update registry""" if not isinstance(path, Path): path = Path(path) _MODEL_CONFIG_PATHS.append(path) _rescan_model_configs() def get_model_config(model_name): if model_name in _MODEL_CONFIGS: return deepcopy(_MODEL_CONFIGS[model_name]) else: return None