|
import json
|
|
import logging
|
|
import os
|
|
import pathlib
|
|
import re
|
|
from copy import deepcopy
|
|
from pathlib import Path
|
|
from typing import Optional, Tuple, Union, Dict, Any
|
|
import torch
|
|
|
|
# Locations searched for model config files: each entry is either a directory
# or a single file. Starts with the "model_configs" directory next to this module.
_MODEL_CONFIG_PATHS = [Path(__file__).parent / "model_configs/"]

# Registry mapping a config file's stem (used as the model name) to its parsed
# JSON content.
_MODEL_CONFIGS = {}
|
|
|
|
|
|
def _natural_key(string_):
    """Build a sort key that orders embedded numbers by numeric value.

    The lower-cased input is split on runs of decimal digits. Digit runs are
    converted to ``int`` and all other pieces are kept as ``str``, so that
    ``"a10"`` sorts after ``"a2"``.

    Args:
        string_: Text to derive the key from.

    Returns:
        A list of ``str`` and ``int`` pieces, in their original order.
    """
    # ``isdecimal`` is true exactly for the digit runs captured by ``\d+`` and
    # parseable by ``int``. ``isdigit`` is also true for characters such as
    # superscript digits, for which ``int`` raises ValueError.
    return [
        int(piece) if piece.isdecimal() else piece
        for piece in re.split(r"(\d+)", string_.lower())
    ]
|
|
|
|
|
|
def _rescan_model_configs():
    """Reload model configs from every registered path into the registry.

    Each entry of ``_MODEL_CONFIG_PATHS`` is either a ``.json`` file, taken
    as-is, or a directory, whose direct ``*.json`` children are taken. Every
    collected file is parsed as JSON and stored in ``_MODEL_CONFIGS`` under the
    file's stem, provided it contains the keys ``embed_dim``, ``vision_cfg``
    and ``text_cfg``. Existing registry entries are kept (or overwritten when
    the stem matches), and the registry is finally rebuilt in natural-sort
    order of its names.
    """
    global _MODEL_CONFIGS

    suffixes = (".json",)
    required_keys = ("embed_dim", "vision_cfg", "text_cfg")

    # Gather candidate files from explicit file entries and directory entries.
    candidates = []
    for entry in _MODEL_CONFIG_PATHS:
        if entry.is_file() and entry.suffix in suffixes:
            candidates.append(entry)
            continue
        if entry.is_dir():
            for suffix in suffixes:
                candidates.extend(entry.glob(f"*{suffix}"))

    # Parse each candidate; files missing any required key are skipped.
    for candidate in candidates:
        with open(candidate, "r", encoding="utf8") as handle:
            parsed = json.load(handle)
            if not all(key in parsed for key in required_keys):
                continue
            _MODEL_CONFIGS[candidate.stem] = parsed

    # Rebuild the registry so iteration follows natural order of model names.
    ordered_names = sorted(_MODEL_CONFIGS, key=_natural_key)
    _MODEL_CONFIGS = {name: _MODEL_CONFIGS[name] for name in ordered_names}
|
|
|
|
|
|
# Populate the model config registry once, when this module is imported.
_rescan_model_configs()
|
|
|
|
|
|
def list_models():
    """Return the names of the available model architectures.

    The names are the keys of the config registry, returned as a new list in
    the registry's current order.
    """
    return [*_MODEL_CONFIGS]
|
|
|
|
|
|
def add_model_config(path):
    """Register an extra config file or directory and refresh the registry.

    Args:
        path: Location of a config file or a directory of config files.
            Anything that is not already a ``Path`` is converted to one.
    """
    entry = path if isinstance(path, Path) else Path(path)
    _MODEL_CONFIG_PATHS.append(entry)
    # Rescan so configs under the new path become visible immediately.
    _rescan_model_configs()
|
|
|
|
|
|
def get_model_config(model_name):
    """Look up the config registered under ``model_name``.

    Args:
        model_name: Registry key to look up.

    Returns:
        A deep copy of the stored config, so callers can modify it without
        affecting the registry, or ``None`` when the name is not registered.
    """
    if model_name not in _MODEL_CONFIGS:
        return None
    return deepcopy(_MODEL_CONFIGS[model_name])
|
|
|