# edge_vlm/mobileclip/__init__.py
# (File-viewer residue: uploaded by irotem98, commit "moondream_model_state_dict.pt", 495fe55, 2.84 kB)
#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2024 Apple Inc. All Rights Reserved.
#
import os
import json
from typing import Optional, Union, Tuple, Any
import torch
import torch.nn as nn
from torchvision.transforms import (
CenterCrop,
Compose,
InterpolationMode,
Resize,
ToTensor,
)
from mobileclip.clip import CLIP
from mobileclip.modules.text.tokenizer import (
ClipTokenizer,
)
from mobileclip.modules.common.mobileone import reparameterize_model
def create_model_and_transforms(
    model_name: str,
    pretrained: Optional[str] = None,
    reparameterize: Optional[bool] = True,
    device: Union[str, torch.device] = "cpu",
) -> Tuple[nn.Module, Any, Any]:
    """
    Method to instantiate model and pre-processing transforms necessary for inference.

    Args:
        model_name: Model name. Choose from ['mobileclip_s0', 'mobileclip_s1', 'mobileclip_s2', 'mobileclip_b']
        pretrained: Location of pretrained checkpoint. When None, no weights are loaded.
        reparameterize: When set to True, re-parameterizable branches get folded for faster inference.
        device: Device identifier for model placement.

    Returns:
        A 3-tuple ``(model, None, preprocess)``: the instantiated model in eval
        mode, a middle slot that is always ``None``, and the preprocessing
        transform (resize, center crop, to-tensor) for inference.

    Raises:
        ValueError: If no ``<model_name>.json`` file exists in the ``configs``
            directory located next to this file.
    """
    # Config files
    root_dir = os.path.dirname(os.path.abspath(__file__))
    configs_dir = os.path.join(root_dir, "configs")
    model_cfg_file = os.path.join(configs_dir, model_name + ".json")

    # Get config from JSON file
    if not os.path.exists(model_cfg_file):
        raise ValueError(f"Unsupported model name: {model_name}")
    # Use a context manager so the config file handle is closed deterministically.
    with open(model_cfg_file, "r", encoding="utf-8") as cfg_file:
        model_cfg = json.load(cfg_file)

    # Build preprocessing transforms for inference; the configured image size
    # is used for both the resize and the center crop.
    resolution = model_cfg["image_cfg"]["image_size"]
    resize_size = resolution
    centercrop_size = resolution
    aug_list = [
        Resize(
            resize_size,
            interpolation=InterpolationMode.BILINEAR,
        ),
        CenterCrop(centercrop_size),
        ToTensor(),
    ]
    preprocess = Compose(aug_list)

    # Build model
    model = CLIP(cfg=model_cfg)
    model.to(device)
    model.eval()

    # Load checkpoint
    if pretrained is not None:
        # map_location remaps the stored tensors onto the requested device, so
        # a checkpoint saved on a GPU can still be loaded on a CPU-only host.
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from trusted sources.
        chkpt = torch.load(pretrained, map_location=device)
        model.load_state_dict(chkpt)

    # Reparameterize model for inference (if specified)
    if reparameterize:
        model = reparameterize_model(model)

    return model, None, preprocess
def get_tokenizer(model_name: str) -> nn.Module:
    """
    Build the text tokenizer for the model configuration named ``model_name``.

    Args:
        model_name: Base name (without ``.json``) of a config file in the
            ``configs`` directory located next to this file.

    Returns:
        A ``ClipTokenizer`` constructed from the loaded model configuration.

    Raises:
        FileNotFoundError: If ``<model_name>.json`` does not exist in the
            ``configs`` directory (raised when opening the file).
    """
    # Config files
    root_dir = os.path.dirname(os.path.abspath(__file__))
    configs_dir = os.path.join(root_dir, "configs")
    model_cfg_file = os.path.join(configs_dir, model_name + ".json")

    # Get config from JSON file; the context manager closes the handle promptly.
    with open(model_cfg_file, "r", encoding="utf-8") as cfg_file:
        model_cfg = json.load(cfg_file)

    # Build tokenizer
    text_tokenizer = ClipTokenizer(model_cfg)
    return text_tokenizer